summaryrefslogtreecommitdiff
path: root/src/pkg/net
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/net')
-rw-r--r--src/pkg/net/Makefile66
-rw-r--r--src/pkg/net/cgo_stub.go12
-rw-r--r--src/pkg/net/cgo_unix.go15
-rw-r--r--src/pkg/net/dial.go249
-rw-r--r--src/pkg/net/dial_test.go88
-rw-r--r--src/pkg/net/dialgoogle_test.go2
-rw-r--r--src/pkg/net/dict/Makefile7
-rw-r--r--src/pkg/net/dict/dict.go211
-rw-r--r--src/pkg/net/dnsclient.go23
-rw-r--r--src/pkg/net/dnsclient_unix.go39
-rw-r--r--src/pkg/net/dnsconfig.go14
-rw-r--r--src/pkg/net/doc.go59
-rw-r--r--src/pkg/net/fd.go317
-rw-r--r--src/pkg/net/fd_darwin.go29
-rw-r--r--src/pkg/net/fd_freebsd.go20
-rw-r--r--src/pkg/net/fd_linux.go40
-rw-r--r--src/pkg/net/fd_netbsd.go116
-rw-r--r--src/pkg/net/fd_openbsd.go20
-rw-r--r--src/pkg/net/fd_windows.go176
-rw-r--r--src/pkg/net/file.go20
-rw-r--r--src/pkg/net/file_plan9.go6
-rw-r--r--src/pkg/net/file_test.go22
-rw-r--r--src/pkg/net/file_windows.go6
-rw-r--r--src/pkg/net/hosts.go12
-rw-r--r--src/pkg/net/http/Makefile25
-rw-r--r--src/pkg/net/http/cgi/Makefile12
-rw-r--r--src/pkg/net/http/cgi/child.go192
-rw-r--r--src/pkg/net/http/cgi/child_test.go87
-rw-r--r--src/pkg/net/http/cgi/host.go350
-rw-r--r--src/pkg/net/http/cgi/host_test.go477
-rw-r--r--src/pkg/net/http/cgi/matryoshka_test.go74
-rwxr-xr-xsrc/pkg/net/http/cgi/testdata/test.cgi96
-rw-r--r--src/pkg/net/http/chunked.go170
-rw-r--r--src/pkg/net/http/chunked_test.go39
-rw-r--r--src/pkg/net/http/client.go326
-rw-r--r--src/pkg/net/http/client_test.go442
-rw-r--r--src/pkg/net/http/cookie.go267
-rw-r--r--src/pkg/net/http/cookie_test.go200
-rw-r--r--src/pkg/net/http/doc.go80
-rw-r--r--src/pkg/net/http/export_test.go43
-rw-r--r--src/pkg/net/http/fcgi/Makefile12
-rw-r--r--src/pkg/net/http/fcgi/child.go271
-rw-r--r--src/pkg/net/http/fcgi/fcgi.go274
-rw-r--r--src/pkg/net/http/fcgi/fcgi_test.go150
-rw-r--r--src/pkg/net/http/filetransport.go123
-rw-r--r--src/pkg/net/http/filetransport_test.go65
-rw-r--r--src/pkg/net/http/fs.go330
-rw-r--r--src/pkg/net/http/fs_test.go334
-rw-r--r--src/pkg/net/http/header.go78
-rw-r--r--src/pkg/net/http/header_test.go81
-rw-r--r--src/pkg/net/http/httptest/Makefile12
-rw-r--r--src/pkg/net/http/httptest/recorder.go58
-rw-r--r--src/pkg/net/http/httptest/server.go173
-rw-r--r--src/pkg/net/http/httputil/Makefile14
-rw-r--r--src/pkg/net/http/httputil/chunked.go172
-rw-r--r--src/pkg/net/http/httputil/chunked_test.go41
-rw-r--r--src/pkg/net/http/httputil/dump.go196
-rw-r--r--src/pkg/net/http/httputil/dump_test.go140
-rw-r--r--src/pkg/net/http/httputil/persist.go422
-rw-r--r--src/pkg/net/http/httputil/reverseproxy.go167
-rw-r--r--src/pkg/net/http/httputil/reverseproxy_test.go71
-rw-r--r--src/pkg/net/http/jar.go30
-rw-r--r--src/pkg/net/http/lex.go144
-rw-r--r--src/pkg/net/http/lex_test.go70
-rw-r--r--src/pkg/net/http/pprof/Makefile11
-rw-r--r--src/pkg/net/http/pprof/pprof.go133
-rw-r--r--src/pkg/net/http/proxy_test.go48
-rw-r--r--src/pkg/net/http/range_test.go57
-rw-r--r--src/pkg/net/http/readrequest_test.go283
-rw-r--r--src/pkg/net/http/request.go741
-rw-r--r--src/pkg/net/http/request_test.go283
-rw-r--r--src/pkg/net/http/requestwrite_test.go438
-rw-r--r--src/pkg/net/http/response.go236
-rw-r--r--src/pkg/net/http/response_test.go448
-rw-r--r--src/pkg/net/http/responsewrite_test.go109
-rw-r--r--src/pkg/net/http/serve_test.go1182
-rw-r--r--src/pkg/net/http/server.go1191
-rw-r--r--src/pkg/net/http/sniff.go214
-rw-r--r--src/pkg/net/http/sniff_test.go137
-rw-r--r--src/pkg/net/http/status.go106
-rw-r--r--src/pkg/net/http/testdata/file1
-rw-r--r--src/pkg/net/http/testdata/index.html1
-rw-r--r--src/pkg/net/http/testdata/style.css1
-rw-r--r--src/pkg/net/http/transfer.go630
-rw-r--r--src/pkg/net/http/transport.go736
-rw-r--r--src/pkg/net/http/transport_test.go695
-rw-r--r--src/pkg/net/http/triv.go149
-rw-r--r--src/pkg/net/interface.go38
-rw-r--r--src/pkg/net/interface_bsd.go104
-rw-r--r--src/pkg/net/interface_darwin.go12
-rw-r--r--src/pkg/net/interface_freebsd.go12
-rw-r--r--src/pkg/net/interface_linux.go100
-rw-r--r--src/pkg/net/interface_netbsd.go14
-rw-r--r--src/pkg/net/interface_openbsd.go4
-rw-r--r--src/pkg/net/interface_stub.go8
-rw-r--r--src/pkg/net/interface_test.go62
-rw-r--r--src/pkg/net/interface_windows.go18
-rw-r--r--src/pkg/net/ip.go9
-rw-r--r--src/pkg/net/ip_test.go5
-rw-r--r--src/pkg/net/ipraw_test.go230
-rw-r--r--src/pkg/net/iprawsock.go8
-rw-r--r--src/pkg/net/iprawsock_plan9.go78
-rw-r--r--src/pkg/net/iprawsock_posix.go160
-rw-r--r--src/pkg/net/ipsock.go14
-rw-r--r--src/pkg/net/ipsock_plan9.go54
-rw-r--r--src/pkg/net/ipsock_posix.go30
-rw-r--r--src/pkg/net/lookup_plan9.go72
-rw-r--r--src/pkg/net/lookup_test.go13
-rw-r--r--src/pkg/net/lookup_unix.go87
-rw-r--r--src/pkg/net/lookup_windows.go81
-rw-r--r--src/pkg/net/mail/Makefile11
-rw-r--r--src/pkg/net/mail/message.go524
-rw-r--r--src/pkg/net/mail/message_test.go261
-rw-r--r--src/pkg/net/multicast_test.go130
-rw-r--r--src/pkg/net/net.go119
-rw-r--r--src/pkg/net/net_test.go60
-rw-r--r--src/pkg/net/newpollserver.go11
-rw-r--r--src/pkg/net/parse.go4
-rw-r--r--src/pkg/net/parse_test.go2
-rw-r--r--src/pkg/net/pipe.go21
-rw-r--r--src/pkg/net/pipe_test.go5
-rw-r--r--src/pkg/net/port.go11
-rw-r--r--src/pkg/net/rpc/Makefile13
-rw-r--r--src/pkg/net/rpc/client.go288
-rw-r--r--src/pkg/net/rpc/debug.go90
-rw-r--r--src/pkg/net/rpc/jsonrpc/Makefile12
-rw-r--r--src/pkg/net/rpc/jsonrpc/all_test.go221
-rw-r--r--src/pkg/net/rpc/jsonrpc/client.go123
-rw-r--r--src/pkg/net/rpc/jsonrpc/server.go136
-rw-r--r--src/pkg/net/rpc/server.go640
-rw-r--r--src/pkg/net/rpc/server_test.go596
-rw-r--r--src/pkg/net/sendfile_linux.go17
-rw-r--r--src/pkg/net/sendfile_stub.go9
-rw-r--r--src/pkg/net/sendfile_windows.go4
-rw-r--r--src/pkg/net/server_test.go20
-rw-r--r--src/pkg/net/smtp/Makefile12
-rw-r--r--src/pkg/net/smtp/auth.go98
-rw-r--r--src/pkg/net/smtp/smtp.go293
-rw-r--r--src/pkg/net/smtp/smtp_test.go182
-rw-r--r--src/pkg/net/sock.go117
-rw-r--r--src/pkg/net/sock_bsd.go38
-rw-r--r--src/pkg/net/sock_linux.go32
-rw-r--r--src/pkg/net/sock_windows.go21
-rw-r--r--src/pkg/net/sockopt.go180
-rw-r--r--src/pkg/net/sockopt_bsd.go45
-rw-r--r--src/pkg/net/sockopt_linux.go37
-rw-r--r--src/pkg/net/sockopt_windows.go38
-rw-r--r--src/pkg/net/sockoptip.go187
-rw-r--r--src/pkg/net/sockoptip_bsd.go54
-rw-r--r--src/pkg/net/sockoptip_darwin.go78
-rw-r--r--src/pkg/net/sockoptip_freebsd.go80
-rw-r--r--src/pkg/net/sockoptip_linux.go120
-rw-r--r--src/pkg/net/sockoptip_openbsd.go78
-rw-r--r--src/pkg/net/sockoptip_windows.go61
-rw-r--r--src/pkg/net/tcpsock.go6
-rw-r--r--src/pkg/net/tcpsock_plan9.go42
-rw-r--r--src/pkg/net/tcpsock_posix.go88
-rw-r--r--src/pkg/net/textproto/header.go2
-rw-r--r--src/pkg/net/textproto/pipeline.go2
-rw-r--r--src/pkg/net/textproto/reader.go130
-rw-r--r--src/pkg/net/textproto/reader_test.go35
-rw-r--r--src/pkg/net/textproto/textproto.go11
-rw-r--r--src/pkg/net/textproto/writer.go7
-rw-r--r--src/pkg/net/timeout_test.go83
-rw-r--r--src/pkg/net/udp_test.go87
-rw-r--r--src/pkg/net/udpsock.go6
-rw-r--r--src/pkg/net/udpsock_plan9.go45
-rw-r--r--src/pkg/net/udpsock_posix.go155
-rw-r--r--src/pkg/net/unicast_test.go111
-rw-r--r--src/pkg/net/unixsock.go6
-rw-r--r--src/pkg/net/unixsock_plan9.go31
-rw-r--r--src/pkg/net/unixsock_posix.go118
-rw-r--r--src/pkg/net/url/Makefile11
-rw-r--r--src/pkg/net/url/url.go664
-rw-r--r--src/pkg/net/url/url_test.go771
175 files changed, 22537 insertions, 1873 deletions
diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile
index eba9e26d9..a02798c73 100644
--- a/src/pkg/net/Makefile
+++ b/src/pkg/net/Makefile
@@ -9,6 +9,7 @@ GOFILES=\
dial.go\
dnsclient.go\
dnsmsg.go\
+ doc.go\
hosts.go\
interface.go\
ip.go\
@@ -21,14 +22,14 @@ GOFILES=\
udpsock.go\
unixsock.go\
-GOFILES_freebsd=\
+GOFILES_darwin=\
dnsclient_unix.go\
dnsconfig.go\
fd.go\
fd_$(GOOS).go\
file.go\
interface_bsd.go\
- interface_freebsd.go\
+ interface_darwin.go\
iprawsock_posix.go\
ipsock_posix.go\
lookup_unix.go\
@@ -37,26 +38,31 @@ GOFILES_freebsd=\
sendfile_stub.go\
sock.go\
sock_bsd.go\
+ sockopt.go\
+ sockopt_bsd.go\
+ sockoptip.go\
+ sockoptip_bsd.go\
+ sockoptip_darwin.go\
tcpsock_posix.go\
udpsock_posix.go\
unixsock_posix.go\
ifeq ($(CGO_ENABLED),1)
-CGOFILES_freebsd=\
+CGOFILES_darwin=\
cgo_bsd.go\
cgo_unix.go
else
-GOFILES_freebsd+=cgo_stub.go
+GOFILES_darwin+=cgo_stub.go
endif
-GOFILES_darwin=\
+GOFILES_freebsd=\
dnsclient_unix.go\
dnsconfig.go\
fd.go\
fd_$(GOOS).go\
file.go\
interface_bsd.go\
- interface_darwin.go\
+ interface_freebsd.go\
iprawsock_posix.go\
ipsock_posix.go\
lookup_unix.go\
@@ -65,16 +71,21 @@ GOFILES_darwin=\
sendfile_stub.go\
sock.go\
sock_bsd.go\
+ sockopt.go\
+ sockopt_bsd.go\
+ sockoptip.go\
+ sockoptip_bsd.go\
+ sockoptip_freebsd.go\
tcpsock_posix.go\
udpsock_posix.go\
unixsock_posix.go\
ifeq ($(CGO_ENABLED),1)
-CGOFILES_darwin=\
+CGOFILES_freebsd=\
cgo_bsd.go\
cgo_unix.go
else
-GOFILES_darwin+=cgo_stub.go
+GOFILES_freebsd+=cgo_stub.go
endif
GOFILES_linux=\
@@ -92,6 +103,10 @@ GOFILES_linux=\
sendfile_linux.go\
sock.go\
sock_linux.go\
+ sockopt.go\
+ sockopt_linux.go\
+ sockoptip.go\
+ sockoptip_linux.go\
tcpsock_posix.go\
udpsock_posix.go\
unixsock_posix.go\
@@ -104,6 +119,32 @@ else
GOFILES_linux+=cgo_stub.go
endif
+GOFILES_netbsd=\
+ dnsclient_unix.go\
+ dnsconfig.go\
+ fd.go\
+ fd_$(GOOS).go\
+ file.go\
+ interface_bsd.go\
+ interface_netbsd.go\
+ iprawsock_posix.go\
+ ipsock_posix.go\
+ lookup_unix.go\
+ newpollserver.go\
+ port.go\
+ sendfile_stub.go\
+ sock.go\
+ sock_bsd.go\
+ sockopt.go\
+ sockopt_bsd.go\
+ sockoptip.go\
+ sockoptip_bsd.go\
+ sockoptip_netbsd.go\
+ tcpsock_posix.go\
+ udpsock_posix.go\
+ unixsock_posix.go\
+ cgo_stub.go\
+
GOFILES_openbsd=\
dnsclient_unix.go\
dnsconfig.go\
@@ -120,6 +161,11 @@ GOFILES_openbsd=\
sendfile_stub.go\
sock.go\
sock_bsd.go\
+ sockopt.go\
+ sockopt_bsd.go\
+ sockoptip.go\
+ sockoptip_bsd.go\
+ sockoptip_openbsd.go\
tcpsock_posix.go\
udpsock_posix.go\
unixsock_posix.go\
@@ -145,6 +191,10 @@ GOFILES_windows=\
sendfile_windows.go\
sock.go\
sock_windows.go\
+ sockopt.go\
+ sockopt_windows.go\
+ sockoptip.go\
+ sockoptip_windows.go\
tcpsock_posix.go\
udpsock_posix.go\
unixsock_posix.go\
diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go
index 565cbe7fe..52e57d740 100644
--- a/src/pkg/net/cgo_stub.go
+++ b/src/pkg/net/cgo_stub.go
@@ -2,26 +2,24 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build openbsd
+// +build !cgo
// Stub cgo routines for systems that do not use cgo to do network lookups.
package net
-import "os"
-
-func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) {
+func cgoLookupHost(name string) (addrs []string, err error, completed bool) {
return nil, nil, false
}
-func cgoLookupPort(network, service string) (port int, err os.Error, completed bool) {
+func cgoLookupPort(network, service string) (port int, err error, completed bool) {
return 0, nil, false
}
-func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) {
+func cgoLookupIP(name string) (addrs []IP, err error, completed bool) {
return nil, nil, false
}
-func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) {
+func cgoLookupCNAME(name string) (cname string, err error, completed bool) {
return "", nil, false
}
diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go
index ec2a393e8..36a3f3d34 100644
--- a/src/pkg/net/cgo_unix.go
+++ b/src/pkg/net/cgo_unix.go
@@ -18,12 +18,11 @@ package net
import "C"
import (
- "os"
"syscall"
"unsafe"
)
-func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) {
+func cgoLookupHost(name string) (addrs []string, err error, completed bool) {
ip, err, completed := cgoLookupIP(name)
for _, p := range ip {
addrs = append(addrs, p.String())
@@ -31,7 +30,7 @@ func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) {
return
}
-func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) {
+func cgoLookupPort(net, service string) (port int, err error, completed bool) {
var res *C.struct_addrinfo
var hints C.struct_addrinfo
@@ -78,7 +77,7 @@ func cgoLookupPort(net, service string) (port int, err os.Error, completed bool)
return 0, &AddrError{"unknown port", net + "/" + service}, true
}
-func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) {
+func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, completed bool) {
var res *C.struct_addrinfo
var hints C.struct_addrinfo
@@ -98,11 +97,11 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, comp
if gerrno == C.EAI_NONAME {
str = noSuchHost
} else if gerrno == C.EAI_SYSTEM {
- str = err.String()
+ str = err.Error()
} else {
str = C.GoString(C.gai_strerror(gerrno))
}
- return nil, "", &DNSError{Error: str, Name: name}, true
+ return nil, "", &DNSError{Err: str, Name: name}, true
}
defer C.freeaddrinfo(res)
if res != nil {
@@ -133,12 +132,12 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, comp
return addrs, cname, nil, true
}
-func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) {
+func cgoLookupIP(name string) (addrs []IP, err error, completed bool) {
addrs, _, err, completed = cgoLookupIPCNAME(name)
return
}
-func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) {
+func cgoLookupCNAME(name string) (cname string, err error, completed bool) {
_, cname, err, completed = cgoLookupIPCNAME(name)
return
}
diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go
index 10c67dcc4..5d596bcb6 100644
--- a/src/pkg/net/dial.go
+++ b/src/pkg/net/dial.go
@@ -4,26 +4,63 @@
package net
-import "os"
+import (
+ "time"
+)
-func resolveNetAddr(op, net, addr string) (a Addr, err os.Error) {
- if addr == "" {
- return nil, &OpError{op, net, nil, errMissingAddress}
+func parseDialNetwork(net string) (afnet string, proto int, err error) {
+ i := last(net, ':')
+ if i < 0 { // no colon
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ case "udp", "udp4", "udp6":
+ case "unix", "unixgram", "unixpacket":
+ default:
+ return "", 0, UnknownNetworkError(net)
+ }
+ return net, 0, nil
}
- switch net {
- case "tcp", "tcp4", "tcp6":
- a, err = ResolveTCPAddr(net, addr)
- case "udp", "udp4", "udp6":
- a, err = ResolveUDPAddr(net, addr)
- case "unix", "unixgram", "unixpacket":
- a, err = ResolveUnixAddr(net, addr)
+ afnet = net[:i]
+ switch afnet {
case "ip", "ip4", "ip6":
- a, err = ResolveIPAddr(net, addr)
- default:
- err = UnknownNetworkError(net)
+ protostr := net[i+1:]
+ proto, i, ok := dtoi(protostr, 0)
+ if !ok || i != len(protostr) {
+ proto, err = lookupProtocol(protostr)
+ if err != nil {
+ return "", 0, err
+ }
+ }
+ return afnet, proto, nil
}
+ return "", 0, UnknownNetworkError(net)
+}
+
+func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) {
+ afnet, _, err = parseDialNetwork(net)
if err != nil {
- return nil, &OpError{op, net + " " + addr, nil, err}
+ return "", nil, &OpError{op, net, nil, err}
+ }
+ if op == "dial" && addr == "" {
+ return "", nil, &OpError{op, net, nil, errMissingAddress}
+ }
+ switch afnet {
+ case "tcp", "tcp4", "tcp6":
+ if addr != "" {
+ a, err = ResolveTCPAddr(afnet, addr)
+ }
+ case "udp", "udp4", "udp6":
+ if addr != "" {
+ a, err = ResolveUDPAddr(afnet, addr)
+ }
+ case "ip", "ip4", "ip6":
+ if addr != "" {
+ a, err = ResolveIPAddr(afnet, addr)
+ }
+ case "unix", "unixgram", "unixpacket":
+ if addr != "" {
+ a, err = ResolveUnixAddr(afnet, addr)
+ }
}
return
}
@@ -32,122 +69,160 @@ func resolveNetAddr(op, net, addr string) (a Addr, err os.Error) {
//
// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
-// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixgram".
+// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and "unixpacket".
//
-// For IP networks, addresses have the form host:port. If host is
-// a literal IPv6 address, it must be enclosed in square brackets.
-// The functions JoinHostPort and SplitHostPort manipulate
-// addresses in this form.
+// For TCP and UDP networks, addresses have the form host:port.
+// If host is a literal IPv6 address, it must be enclosed
+// in square brackets. The functions JoinHostPort and SplitHostPort
+// manipulate addresses in this form.
//
// Examples:
// Dial("tcp", "12.34.56.78:80")
// Dial("tcp", "google.com:80")
// Dial("tcp", "[de:ad:be:ef::ca:fe]:80")
//
-func Dial(net, addr string) (c Conn, err os.Error) {
- addri, err := resolveNetAddr("dial", net, addr)
+// For IP networks, addr must be "ip", "ip4" or "ip6" followed
+// by a colon and a protocol number or name.
+//
+// Examples:
+// Dial("ip4:1", "127.0.0.1")
+// Dial("ip6:ospf", "::1")
+//
+func Dial(net, addr string) (Conn, error) {
+ _, addri, err := resolveNetAddr("dial", net, addr)
if err != nil {
return nil, err
}
+ return dialAddr(net, addr, addri)
+}
+
+func dialAddr(net, addr string, addri Addr) (c Conn, err error) {
switch ra := addri.(type) {
case *TCPAddr:
c, err = DialTCP(net, nil, ra)
case *UDPAddr:
c, err = DialUDP(net, nil, ra)
- case *UnixAddr:
- c, err = DialUnix(net, nil, ra)
case *IPAddr:
c, err = DialIP(net, nil, ra)
+ case *UnixAddr:
+ c, err = DialUnix(net, nil, ra)
default:
- err = UnknownNetworkError(net)
+ err = &OpError{"dial", net + " " + addr, nil, UnknownNetworkError(net)}
}
if err != nil {
- return nil, &OpError{"dial", net + " " + addr, nil, err}
+ return nil, err
}
return
}
+// DialTimeout acts like Dial but takes a timeout.
+// The timeout includes name resolution, if required.
+func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ // TODO(bradfitz): the timeout should be pushed down into the
+ // net package's event loop, so on timeout to dead hosts we
+ // don't have a goroutine sticking around for the default of
+ // ~3 minutes.
+ t := time.NewTimer(timeout)
+ defer t.Stop()
+ type pair struct {
+ Conn
+ error
+ }
+ ch := make(chan pair, 1)
+ resolvedAddr := make(chan Addr, 1)
+ go func() {
+ _, addri, err := resolveNetAddr("dial", net, addr)
+ if err != nil {
+ ch <- pair{nil, err}
+ return
+ }
+ resolvedAddr <- addri // in case we need it for OpError
+ c, err := dialAddr(net, addr, addri)
+ ch <- pair{c, err}
+ }()
+ select {
+ case <-t.C:
+ // Try to use the real Addr in our OpError, if we resolved it
+ // before the timeout. Otherwise we just use stringAddr.
+ var addri Addr
+ select {
+ case a := <-resolvedAddr:
+ addri = a
+ default:
+ addri = &stringAddr{net, addr}
+ }
+ err := &OpError{
+ Op: "dial",
+ Net: net,
+ Addr: addri,
+ Err: &timeoutError{},
+ }
+ return nil, err
+ case p := <-ch:
+ return p.Conn, p.error
+ }
+ panic("unreachable")
+}
+
+type stringAddr struct {
+ net, addr string
+}
+
+func (a stringAddr) Network() string { return a.net }
+func (a stringAddr) String() string { return a.addr }
+
// Listen announces on the local network address laddr.
-// The network string net must be a stream-oriented
-// network: "tcp", "tcp4", "tcp6", or "unix", or "unixpacket".
-func Listen(net, laddr string) (l Listener, err os.Error) {
- switch net {
+// The network string net must be a stream-oriented network:
+// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket".
+func Listen(net, laddr string) (Listener, error) {
+ afnet, a, err := resolveNetAddr("listen", net, laddr)
+ if err != nil {
+ return nil, err
+ }
+ switch afnet {
case "tcp", "tcp4", "tcp6":
var la *TCPAddr
- if laddr != "" {
- if la, err = ResolveTCPAddr(net, laddr); err != nil {
- return nil, err
- }
- }
- l, err := ListenTCP(net, la)
- if err != nil {
- return nil, err
+ if a != nil {
+ la = a.(*TCPAddr)
}
- return l, nil
+ return ListenTCP(afnet, la)
case "unix", "unixpacket":
var la *UnixAddr
- if laddr != "" {
- if la, err = ResolveUnixAddr(net, laddr); err != nil {
- return nil, err
- }
- }
- l, err := ListenUnix(net, la)
- if err != nil {
- return nil, err
+ if a != nil {
+ la = a.(*UnixAddr)
}
- return l, nil
+ return ListenUnix(net, la)
}
return nil, UnknownNetworkError(net)
}
// ListenPacket announces on the local network address laddr.
// The network string net must be a packet-oriented network:
-// "udp", "udp4", "udp6", or "unixgram".
-func ListenPacket(net, laddr string) (c PacketConn, err os.Error) {
- switch net {
+// "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram".
+func ListenPacket(net, addr string) (PacketConn, error) {
+ afnet, a, err := resolveNetAddr("listen", net, addr)
+ if err != nil {
+ return nil, err
+ }
+ switch afnet {
case "udp", "udp4", "udp6":
var la *UDPAddr
- if laddr != "" {
- if la, err = ResolveUDPAddr(net, laddr); err != nil {
- return nil, err
- }
+ if a != nil {
+ la = a.(*UDPAddr)
}
- c, err := ListenUDP(net, la)
- if err != nil {
- return nil, err
+ return ListenUDP(net, la)
+ case "ip", "ip4", "ip6":
+ var la *IPAddr
+ if a != nil {
+ la = a.(*IPAddr)
}
- return c, nil
+ return ListenIP(net, la)
case "unixgram":
var la *UnixAddr
- if laddr != "" {
- if la, err = ResolveUnixAddr(net, laddr); err != nil {
- return nil, err
- }
- }
- c, err := DialUnix(net, la, nil)
- if err != nil {
- return nil, err
- }
- return c, nil
- }
-
- var rawnet string
- if rawnet, _, err = splitNetProto(net); err != nil {
- switch rawnet {
- case "ip", "ip4", "ip6":
- var la *IPAddr
- if laddr != "" {
- if la, err = ResolveIPAddr(rawnet, laddr); err != nil {
- return nil, err
- }
- }
- c, err := ListenIP(net, la)
- if err != nil {
- return nil, err
- }
- return c, nil
+ if a != nil {
+ la = a.(*UnixAddr)
}
+ return DialUnix(net, la, nil)
}
-
return nil, UnknownNetworkError(net)
}
diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go
new file mode 100644
index 000000000..16b726311
--- /dev/null
+++ b/src/pkg/net/dial_test.go
@@ -0,0 +1,88 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "testing"
+ "time"
+)
+
+func newLocalListener(t *testing.T) Listener {
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+func TestDialTimeout(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ errc := make(chan error)
+
+ const SOMAXCONN = 0x80 // copied from syscall, but not always available
+ const numConns = SOMAXCONN + 10
+
+ // TODO(bradfitz): It's hard to test this in a portable
+ // way. This is unforunate, but works for now.
+ switch runtime.GOOS {
+ case "linux":
+ // The kernel will start accepting TCP connections before userspace
+ // gets a chance to not accept them, so fire off a bunch to fill up
+ // the kernel's backlog. Then we test we get a failure after that.
+ for i := 0; i < numConns; i++ {
+ go func() {
+ _, err := DialTimeout("tcp", ln.Addr().String(), 200*time.Millisecond)
+ errc <- err
+ }()
+ }
+ case "darwin":
+ // At least OS X 10.7 seems to accept any number of
+ // connections, ignoring listen's backlog, so resort
+ // to connecting to a hopefully-dead 127/8 address.
+ go func() {
+ _, err := DialTimeout("tcp", "127.0.71.111:80", 200*time.Millisecond)
+ errc <- err
+ }()
+ default:
+ // TODO(bradfitz): this probably doesn't work on
+ // Windows? SOMAXCONN is huge there. I'm not sure how
+ // listen works there.
+ // OpenBSD may have a reject route to 10/8.
+ // FreeBSD likely works, but is untested.
+ t.Logf("skipping test on %q; untested.", runtime.GOOS)
+ return
+ }
+
+ connected := 0
+ for {
+ select {
+ case <-time.After(15 * time.Second):
+ t.Fatal("too slow")
+ case err := <-errc:
+ if err == nil {
+ connected++
+ if connected == numConns {
+ t.Fatal("all connections connected; expected some to time out")
+ }
+ } else {
+ terr, ok := err.(timeout)
+ if !ok {
+ t.Fatalf("got error %q; want error with timeout interface", err)
+ }
+ if !terr.Timeout() {
+ t.Fatalf("got error %q; not a timeout", err)
+ }
+ // Pass. We saw a timeout error.
+ return
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go
index 9ad1770da..81750a3d7 100644
--- a/src/pkg/net/dialgoogle_test.go
+++ b/src/pkg/net/dialgoogle_test.go
@@ -19,7 +19,7 @@ var ipv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present")
// fd is already connected to the destination, port 80.
// Run an HTTP request to fetch the appropriate page.
func fetchGoogle(t *testing.T, fd Conn, network, addr string) {
- req := []byte("GET /intl/en/privacy/ HTTP/1.0\r\nHost: www.google.com\r\n\r\n")
+ req := []byte("GET /robots.txt HTTP/1.0\r\nHost: www.google.com\r\n\r\n")
n, err := fd.Write(req)
buf := make([]byte, 1000)
diff --git a/src/pkg/net/dict/Makefile b/src/pkg/net/dict/Makefile
deleted file mode 100644
index eaa9e6531..000000000
--- a/src/pkg/net/dict/Makefile
+++ /dev/null
@@ -1,7 +0,0 @@
-include ../../../Make.inc
-
-TARG=net/dict
-GOFILES=\
- dict.go\
-
-include ../../../Make.pkg
diff --git a/src/pkg/net/dict/dict.go b/src/pkg/net/dict/dict.go
deleted file mode 100644
index b146ea212..000000000
--- a/src/pkg/net/dict/dict.go
+++ /dev/null
@@ -1,211 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package dict implements the Dictionary Server Protocol
-// as defined in RFC 2229.
-package dict
-
-import (
- "net/textproto"
- "os"
- "strconv"
- "strings"
-)
-
-// A Client represents a client connection to a dictionary server.
-type Client struct {
- text *textproto.Conn
-}
-
-// Dial returns a new client connected to a dictionary server at
-// addr on the given network.
-func Dial(network, addr string) (*Client, os.Error) {
- text, err := textproto.Dial(network, addr)
- if err != nil {
- return nil, err
- }
- _, _, err = text.ReadCodeLine(220)
- if err != nil {
- text.Close()
- return nil, err
- }
- return &Client{text: text}, nil
-}
-
-// Close closes the connection to the dictionary server.
-func (c *Client) Close() os.Error {
- return c.text.Close()
-}
-
-// A Dict represents a dictionary available on the server.
-type Dict struct {
- Name string // short name of dictionary
- Desc string // long description
-}
-
-// Dicts returns a list of the dictionaries available on the server.
-func (c *Client) Dicts() ([]Dict, os.Error) {
- id, err := c.text.Cmd("SHOW DB")
- if err != nil {
- return nil, err
- }
-
- c.text.StartResponse(id)
- defer c.text.EndResponse(id)
-
- _, _, err = c.text.ReadCodeLine(110)
- if err != nil {
- return nil, err
- }
- lines, err := c.text.ReadDotLines()
- if err != nil {
- return nil, err
- }
- _, _, err = c.text.ReadCodeLine(250)
-
- dicts := make([]Dict, len(lines))
- for i := range dicts {
- d := &dicts[i]
- a, _ := fields(lines[i])
- if len(a) < 2 {
- return nil, textproto.ProtocolError("invalid dictionary: " + lines[i])
- }
- d.Name = a[0]
- d.Desc = a[1]
- }
- return dicts, err
-}
-
-// A Defn represents a definition.
-type Defn struct {
- Dict Dict // Dict where definition was found
- Word string // Word being defined
- Text []byte // Definition text, typically multiple lines
-}
-
-// Define requests the definition of the given word.
-// The argument dict names the dictionary to use,
-// the Name field of a Dict returned by Dicts.
-//
-// The special dictionary name "*" means to look in all the
-// server's dictionaries.
-// The special dictionary name "!" means to look in all the
-// server's dictionaries in turn, stopping after finding the word
-// in one of them.
-func (c *Client) Define(dict, word string) ([]*Defn, os.Error) {
- id, err := c.text.Cmd("DEFINE %s %q", dict, word)
- if err != nil {
- return nil, err
- }
-
- c.text.StartResponse(id)
- defer c.text.EndResponse(id)
-
- _, line, err := c.text.ReadCodeLine(150)
- if err != nil {
- return nil, err
- }
- a, _ := fields(line)
- if len(a) < 1 {
- return nil, textproto.ProtocolError("malformed response: " + line)
- }
- n, err := strconv.Atoi(a[0])
- if err != nil {
- return nil, textproto.ProtocolError("invalid definition count: " + a[0])
- }
- def := make([]*Defn, n)
- for i := 0; i < n; i++ {
- _, line, err = c.text.ReadCodeLine(151)
- if err != nil {
- return nil, err
- }
- a, _ := fields(line)
- if len(a) < 3 {
- // skip it, to keep protocol in sync
- i--
- n--
- def = def[0:n]
- continue
- }
- d := &Defn{Word: a[0], Dict: Dict{a[1], a[2]}}
- d.Text, err = c.text.ReadDotBytes()
- if err != nil {
- return nil, err
- }
- def[i] = d
- }
- _, _, err = c.text.ReadCodeLine(250)
- return def, err
-}
-
-// Fields returns the fields in s.
-// Fields are space separated unquoted words
-// or quoted with single or double quote.
-func fields(s string) ([]string, os.Error) {
- var v []string
- i := 0
- for {
- for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
- i++
- }
- if i >= len(s) {
- break
- }
- if s[i] == '"' || s[i] == '\'' {
- q := s[i]
- // quoted string
- var j int
- for j = i + 1; ; j++ {
- if j >= len(s) {
- return nil, textproto.ProtocolError("malformed quoted string")
- }
- if s[j] == '\\' {
- j++
- continue
- }
- if s[j] == q {
- j++
- break
- }
- }
- v = append(v, unquote(s[i+1:j-1]))
- i = j
- } else {
- // atom
- var j int
- for j = i; j < len(s); j++ {
- if s[j] == ' ' || s[j] == '\t' || s[j] == '\\' || s[j] == '"' || s[j] == '\'' {
- break
- }
- }
- v = append(v, s[i:j])
- i = j
- }
- if i < len(s) {
- c := s[i]
- if c != ' ' && c != '\t' {
- return nil, textproto.ProtocolError("quotes not on word boundaries")
- }
- }
- }
- return v, nil
-}
-
-func unquote(s string) string {
- if strings.Index(s, "\\") < 0 {
- return s
- }
- b := []byte(s)
- w := 0
- for r := 0; r < len(b); r++ {
- c := b[r]
- if c == '\\' {
- r++
- c = b[r]
- }
- b[w] = c
- w++
- }
- return string(b[0:w])
-}
diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go
index 93c04f6b5..f4ed8b87c 100644
--- a/src/pkg/net/dnsclient.go
+++ b/src/pkg/net/dnsclient.go
@@ -7,20 +7,19 @@ package net
import (
"bytes"
"fmt"
- "os"
- "rand"
+ "math/rand"
"sort"
)
// DNSError represents a DNS lookup error.
type DNSError struct {
- Error string // description of the error
+ Err string // description of the error
Name string // name looked for
Server string // server used
IsTimeout bool
}
-func (e *DNSError) String() string {
+func (e *DNSError) Error() string {
if e == nil {
return "<nil>"
}
@@ -28,7 +27,7 @@ func (e *DNSError) String() string {
if e.Server != "" {
s += " on " + e.Server
}
- s += ": " + e.Error
+ s += ": " + e.Err
return s
}
@@ -40,10 +39,10 @@ const noSuchHost = "no such host"
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
// address addr suitable for rDNS (PTR) record lookup or an error if it fails
// to parse the IP address.
-func reverseaddr(addr string) (arpa string, err os.Error) {
+func reverseaddr(addr string) (arpa string, err error) {
ip := ParseIP(addr)
if ip == nil {
- return "", &DNSError{Error: "unrecognized address", Name: addr}
+ return "", &DNSError{Err: "unrecognized address", Name: addr}
}
if ip.To4() != nil {
return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil
@@ -64,18 +63,18 @@ func reverseaddr(addr string) (arpa string, err os.Error) {
// Find answer for name in dns message.
// On return, if err == nil, addrs != nil.
-func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err os.Error) {
+func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
addrs = make([]dnsRR, 0, len(dns.answer))
if dns.rcode == dnsRcodeNameError && dns.recursion_available {
- return "", nil, &DNSError{Error: noSuchHost, Name: name}
+ return "", nil, &DNSError{Err: noSuchHost, Name: name}
}
if dns.rcode != dnsRcodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly.
- return "", nil, &DNSError{Error: "server misbehaving", Name: name, Server: server}
+ return "", nil, &DNSError{Err: "server misbehaving", Name: name, Server: server}
}
// Look for the name.
@@ -107,12 +106,12 @@ Cname:
}
}
if len(addrs) == 0 {
- return "", nil, &DNSError{Error: noSuchHost, Name: name, Server: server}
+ return "", nil, &DNSError{Err: noSuchHost, Name: name, Server: server}
}
return name, addrs, nil
}
- return "", nil, &DNSError{Error: "too many redirects", Name: name, Server: server}
+ return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server}
}
func isDomainName(s string) bool {
diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go
index eb7db5e27..18c39360e 100644
--- a/src/pkg/net/dnsclient_unix.go
+++ b/src/pkg/net/dnsclient_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
// DNS client: see RFC 1035.
// Has to be linked into package net for Dial.
@@ -17,27 +17,26 @@
package net
import (
- "os"
- "rand"
+ "math/rand"
"sync"
"time"
)
// Send a request on the connection and hope for a reply.
// Up to cfg.attempts attempts.
-func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Error) {
+func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
if len(name) >= 256 {
- return nil, &DNSError{Error: "name too long", Name: name}
+ return nil, &DNSError{Err: "name too long", Name: name}
}
out := new(dnsMsg)
- out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds())
+ out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
out.question = []dnsQuestion{
{name, qtype, dnsClassINET},
}
out.recursion_desired = true
msg, ok := out.Pack()
if !ok {
- return nil, &DNSError{Error: "internal error - cannot pack message", Name: name}
+ return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
}
for attempt := 0; attempt < cfg.attempts; attempt++ {
@@ -46,7 +45,11 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Er
return nil, err
}
- c.SetReadTimeout(int64(cfg.timeout) * 1e9) // nanoseconds
+ if cfg.timeout == 0 {
+ c.SetReadDeadline(time.Time{})
+ } else {
+ c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
+ }
buf := make([]byte, 2000) // More than enough.
n, err = c.Read(buf)
@@ -67,14 +70,14 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Er
if a := c.RemoteAddr(); a != nil {
server = a.String()
}
- return nil, &DNSError{Error: "no answer from server", Name: name, Server: server, IsTimeout: true}
+ return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) {
+func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
if len(cfg.servers) == 0 {
- return "", nil, &DNSError{Error: "no DNS servers", Name: name}
+ return "", nil, &DNSError{Err: "no DNS servers", Name: name}
}
for i := 0; i < len(cfg.servers); i++ {
// Calling Dial here is scary -- we have to be sure
@@ -96,7 +99,7 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs
continue
}
cname, addrs, err = answer(name, server, msg, qtype)
- if err == nil || err.(*DNSError).Error == noSuchHost {
+ if err == nil || err.(*DNSError).Err == noSuchHost {
break
}
}
@@ -123,15 +126,15 @@ func convertRR_AAAA(records []dnsRR) []IP {
}
var cfg *dnsConfig
-var dnserr os.Error
+var dnserr error
func loadConfig() { cfg, dnserr = dnsReadConfig() }
var onceLoadConfig sync.Once
-func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) {
+func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
if !isDomainName(name) {
- return name, nil, &DNSError{Error: "invalid domain name", Name: name}
+ return name, nil, &DNSError{Err: "invalid domain name", Name: name}
}
onceLoadConfig.Do(loadConfig)
if dnserr != nil || cfg == nil {
@@ -186,7 +189,7 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro
// Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same
// answers.
-func goLookupHost(name string) (addrs []string, err os.Error) {
+func goLookupHost(name string) (addrs []string, err error) {
// Use entries from /etc/hosts if they match.
addrs = lookupStaticHost(name)
if len(addrs) > 0 {
@@ -214,7 +217,7 @@ func goLookupHost(name string) (addrs []string, err os.Error) {
// Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same
// answers.
-func goLookupIP(name string) (addrs []IP, err os.Error) {
+func goLookupIP(name string) (addrs []IP, err error) {
// Use entries from /etc/hosts if possible.
haddrs := lookupStaticHost(name)
if len(haddrs) > 0 {
@@ -260,7 +263,7 @@ func goLookupIP(name string) (addrs []IP, err os.Error) {
// Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same
// answers.
-func goLookupCNAME(name string) (cname string, err os.Error) {
+func goLookupCNAME(name string) (cname string, err error) {
onceLoadConfig.Do(loadConfig)
if dnserr != nil || cfg == nil {
err = dnserr
diff --git a/src/pkg/net/dnsconfig.go b/src/pkg/net/dnsconfig.go
index afc059917..c0ab80288 100644
--- a/src/pkg/net/dnsconfig.go
+++ b/src/pkg/net/dnsconfig.go
@@ -2,14 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
// Read system DNS config from /etc/resolv.conf
package net
-import "os"
-
type dnsConfig struct {
servers []string // servers to use
search []string // suffixes to append to local name
@@ -19,14 +17,14 @@ type dnsConfig struct {
rotate bool // round robin among servers
}
-var dnsconfigError os.Error
+var dnsconfigError error
type DNSConfigError struct {
- Error os.Error
+ Err error
}
-func (e *DNSConfigError) String() string {
- return "error reading DNS config: " + e.Error.String()
+func (e *DNSConfigError) Error() string {
+ return "error reading DNS config: " + e.Err.Error()
}
func (e *DNSConfigError) Timeout() bool { return false }
@@ -36,7 +34,7 @@ func (e *DNSConfigError) Temporary() bool { return false }
// TODO(rsc): Supposed to call uname() and chop the beginning
// of the host name to get the default search domain.
// We assume it's in resolv.conf anyway.
-func dnsReadConfig() (*dnsConfig, os.Error) {
+func dnsReadConfig() (*dnsConfig, error) {
file, err := open("/etc/resolv.conf")
if err != nil {
return nil, &DNSConfigError{err}
diff --git a/src/pkg/net/doc.go b/src/pkg/net/doc.go
new file mode 100644
index 000000000..3a44e528e
--- /dev/null
+++ b/src/pkg/net/doc.go
@@ -0,0 +1,59 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+// LookupHost looks up the given host using the local resolver.
+// It returns an array of that host's addresses.
+func LookupHost(host string) (addrs []string, err error) {
+ return lookupHost(host)
+}
+
+// LookupIP looks up host using the local resolver.
+// It returns an array of that host's IPv4 and IPv6 addresses.
+func LookupIP(host string) (addrs []IP, err error) {
+ return lookupIP(host)
+}
+
+// LookupPort looks up the port for the given network and service.
+func LookupPort(network, service string) (port int, err error) {
+ return lookupPort(network, service)
+}
+
+// LookupCNAME returns the canonical DNS host for the given name.
+// Callers that do not care about the canonical name can call
+// LookupHost or LookupIP directly; both take care of resolving
+// the canonical name as part of the lookup.
+func LookupCNAME(name string) (cname string, err error) {
+ return lookupCNAME(name)
+}
+
+// LookupSRV tries to resolve an SRV query of the given service,
+// protocol, and domain name. The proto is "tcp" or "udp".
+// The returned records are sorted by priority and randomized
+// by weight within a priority.
+//
+// LookupSRV constructs the DNS name to look up following RFC 2782.
+// That is, it looks up _service._proto.name. To accommodate services
+// publishing SRV records under non-standard names, if both service
+// and proto are empty strings, LookupSRV looks up name directly.
+func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+ return lookupSRV(service, proto, name)
+}
+
+// LookupMX returns the DNS MX records for the given domain name sorted by preference.
+func LookupMX(name string) (mx []*MX, err error) {
+ return lookupMX(name)
+}
+
+// LookupTXT returns the DNS TXT records for the given domain name.
+func LookupTXT(name string) (txt []string, err error) {
+ return lookupTXT(name)
+}
+
+// LookupAddr performs a reverse lookup for the given address, returning a list
+// of names mapping to that address.
+func LookupAddr(addr string) (name []string, err error) {
+ return lookupAddr(addr)
+}
diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go
index 9084e8875..495ef007f 100644
--- a/src/pkg/net/fd.go
+++ b/src/pkg/net/fd.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
package net
@@ -22,23 +22,22 @@ type netFD struct {
closing bool
// immutable until Close
- sysfd int
- family int
- proto int
- sysfile *os.File
- cr chan bool
- cw chan bool
- net string
- laddr Addr
- raddr Addr
+ sysfd int
+ family int
+ sotype int
+ isConnected bool
+ sysfile *os.File
+ cr chan bool
+ cw chan bool
+ net string
+ laddr Addr
+ raddr Addr
// owned by client
- rdeadline_delta int64
- rdeadline int64
- rio sync.Mutex
- wdeadline_delta int64
- wdeadline int64
- wio sync.Mutex
+ rdeadline int64
+ rio sync.Mutex
+ wdeadline int64
+ wio sync.Mutex
// owned by fd wait server
ncr, ncw int
@@ -46,7 +45,7 @@ type netFD struct {
type InvalidConnError struct{}
-func (e *InvalidConnError) String() string { return "invalid net.Conn" }
+func (e *InvalidConnError) Error() string { return "invalid net.Conn" }
func (e *InvalidConnError) Temporary() bool { return false }
func (e *InvalidConnError) Timeout() bool { return false }
@@ -126,7 +125,7 @@ func (s *pollServer) AddFD(fd *netFD, mode int) {
wake, err := s.poll.AddFD(intfd, mode, false)
if err != nil {
- panic("pollServer AddFD " + err.String())
+ panic("pollServer AddFD " + err.Error())
}
if wake {
doWakeup = true
@@ -152,7 +151,7 @@ func (s *pollServer) LookupFD(fd int, mode int) *netFD {
if !ok {
return nil
}
- s.pending[key] = nil, false
+ delete(s.pending, key)
return netfd
}
@@ -171,7 +170,7 @@ func (s *pollServer) WakeFD(fd *netFD, mode int) {
}
func (s *pollServer) Now() int64 {
- return time.Nanoseconds()
+ return time.Now().UnixNano()
}
func (s *pollServer) CheckDeadlines() {
@@ -195,7 +194,7 @@ func (s *pollServer) CheckDeadlines() {
}
if t > 0 {
if t <= now {
- s.pending[key] = nil, false
+ delete(s.pending, key)
if mode == 'r' {
s.poll.DelFD(fd.sysfd, mode)
fd.rdeadline = -1
@@ -227,7 +226,7 @@ func (s *pollServer) Run() {
}
fd, mode, err := s.poll.WaitFD(s, t)
if err != nil {
- print("pollServer WaitFD: ", err.String(), "\n")
+ print("pollServer WaitFD: ", err.Error(), "\n")
return
}
if fd < 0 {
@@ -271,20 +270,20 @@ var onceStartServer sync.Once
func startServer() {
p, err := newPollServer()
if err != nil {
- print("Start pollServer: ", err.String(), "\n")
+ print("Start pollServer: ", err.Error(), "\n")
}
pollserver = p
}
-func newFD(fd, family, proto int, net string) (f *netFD, err os.Error) {
+func newFD(fd, family, sotype int, net string) (f *netFD, err error) {
onceStartServer.Do(startServer)
- if e := syscall.SetNonblock(fd, true); e != 0 {
- return nil, os.Errno(e)
+ if e := syscall.SetNonblock(fd, true); e != nil {
+ return nil, e
}
f = &netFD{
sysfd: fd,
family: family,
- proto: proto,
+ sotype: sotype,
net: net,
}
f.cr = make(chan bool, 1)
@@ -305,20 +304,20 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
fd.sysfile = os.NewFile(fd.sysfd, fd.net+":"+ls+"->"+rs)
}
-func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) {
- e := syscall.Connect(fd.sysfd, ra)
- if e == syscall.EINPROGRESS {
- var errno int
+func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
+ err = syscall.Connect(fd.sysfd, ra)
+ if err == syscall.EINPROGRESS {
pollserver.WaitWrite(fd)
- e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
- if errno != 0 {
- return os.NewSyscallError("getsockopt", errno)
+ var e int
+ e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return os.NewSyscallError("getsockopt", err)
+ }
+ if e != 0 {
+ err = syscall.Errno(e)
}
}
- if e != 0 {
- return os.Errno(e)
- }
- return nil
+ return err
}
// Add a reference to this fd.
@@ -346,7 +345,7 @@ func (fd *netFD) decref() {
fd.sysmu.Unlock()
}
-func (fd *netFD) Close() os.Error {
+func (fd *netFD) Close() error {
if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
@@ -358,7 +357,26 @@ func (fd *netFD) Close() os.Error {
return nil
}
-func (fd *netFD) Read(p []byte) (n int, err os.Error) {
+func (fd *netFD) shutdown(how int) error {
+ if fd == nil || fd.sysfile == nil {
+ return os.EINVAL
+ }
+ err := syscall.Shutdown(fd.sysfd, how)
+ if err != nil {
+ return &OpError{"shutdown", fd.net, fd.laddr, err}
+ }
+ return nil
+}
+
+func (fd *netFD) CloseRead() error {
+ return fd.shutdown(syscall.SHUT_RD)
+}
+
+func (fd *netFD) CloseWrite() error {
+ return fd.shutdown(syscall.SHUT_WR)
+}
+
+func (fd *netFD) Read(p []byte) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
}
@@ -369,34 +387,29 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) {
if fd.sysfile == nil {
return 0, os.EINVAL
}
- if fd.rdeadline_delta > 0 {
- fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
- } else {
- fd.rdeadline = 0
- }
- var oserr os.Error
for {
- var errno int
- n, errno = syscall.Read(fd.sysfile.Fd(), p)
- if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
+ n, err = syscall.Read(fd.sysfile.Fd(), p)
+ if err == syscall.EAGAIN {
+ if fd.rdeadline >= 0 {
+ pollserver.WaitRead(fd)
+ continue
+ }
+ err = errTimeout
}
- if errno != 0 {
+ if err != nil {
n = 0
- oserr = os.Errno(errno)
- } else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM {
- err = os.EOF
+ } else if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM {
+ err = io.EOF
}
break
}
- if oserr != nil {
- err = &OpError{"read", fd.net, fd.raddr, oserr}
+ if err != nil && err != io.EOF {
+ err = &OpError{"read", fd.net, fd.raddr, err}
}
return
}
-func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
+func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
if fd == nil || fd.sysfile == nil {
return 0, nil, os.EINVAL
}
@@ -404,32 +417,27 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
defer fd.rio.Unlock()
fd.incref()
defer fd.decref()
- if fd.rdeadline_delta > 0 {
- fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
- } else {
- fd.rdeadline = 0
- }
- var oserr os.Error
for {
- var errno int
- n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0)
- if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
+ n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
+ if err == syscall.EAGAIN {
+ if fd.rdeadline >= 0 {
+ pollserver.WaitRead(fd)
+ continue
+ }
+ err = errTimeout
}
- if errno != 0 {
+ if err != nil {
n = 0
- oserr = os.Errno(errno)
}
break
}
- if oserr != nil {
- err = &OpError{"read", fd.net, fd.laddr, oserr}
+ if err != nil {
+ err = &OpError{"read", fd.net, fd.laddr, err}
}
return
}
-func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) {
+func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
if fd == nil || fd.sysfile == nil {
return 0, 0, 0, nil, os.EINVAL
}
@@ -437,35 +445,28 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
defer fd.rio.Unlock()
fd.incref()
defer fd.decref()
- if fd.rdeadline_delta > 0 {
- fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
- } else {
- fd.rdeadline = 0
- }
- var oserr os.Error
for {
- var errno int
- n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0)
- if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
- }
- if errno != 0 {
- oserr = os.Errno(errno)
+ n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
+ if err == syscall.EAGAIN {
+ if fd.rdeadline >= 0 {
+ pollserver.WaitRead(fd)
+ continue
+ }
+ err = errTimeout
}
- if n == 0 {
- oserr = os.EOF
+ if err == nil && n == 0 {
+ err = io.EOF
}
break
}
- if oserr != nil {
- err = &OpError{"read", fd.net, fd.laddr, oserr}
+ if err != nil && err != io.EOF {
+ err = &OpError{"read", fd.net, fd.laddr, err}
return
}
return
}
-func (fd *netFD) Write(p []byte) (n int, err os.Error) {
+func (fd *netFD) Write(p []byte) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
}
@@ -476,43 +477,40 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
if fd.sysfile == nil {
return 0, os.EINVAL
}
- if fd.wdeadline_delta > 0 {
- fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
- } else {
- fd.wdeadline = 0
- }
nn := 0
- var oserr os.Error
for {
- n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:])
+ var n int
+ n, err = syscall.Write(fd.sysfile.Fd(), p[nn:])
if n > 0 {
nn += n
}
if nn == len(p) {
break
}
- if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
+ if err == syscall.EAGAIN {
+ if fd.wdeadline >= 0 {
+ pollserver.WaitWrite(fd)
+ continue
+ }
+ err = errTimeout
}
- if errno != 0 {
+ if err != nil {
n = 0
- oserr = os.Errno(errno)
break
}
if n == 0 {
- oserr = io.ErrUnexpectedEOF
+ err = io.ErrUnexpectedEOF
break
}
}
- if oserr != nil {
- err = &OpError{"write", fd.net, fd.raddr, oserr}
+ if err != nil {
+ err = &OpError{"write", fd.net, fd.raddr, err}
}
return nn, err
}
-func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
+func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
@@ -520,32 +518,26 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
defer fd.wio.Unlock()
fd.incref()
defer fd.decref()
- if fd.wdeadline_delta > 0 {
- fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
- } else {
- fd.wdeadline = 0
- }
- var oserr os.Error
for {
- errno := syscall.Sendto(fd.sysfd, p, 0, sa)
- if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
- }
- if errno != 0 {
- oserr = os.Errno(errno)
+ err = syscall.Sendto(fd.sysfd, p, 0, sa)
+ if err == syscall.EAGAIN {
+ if fd.wdeadline >= 0 {
+ pollserver.WaitWrite(fd)
+ continue
+ }
+ err = errTimeout
}
break
}
- if oserr == nil {
+ if err == nil {
n = len(p)
} else {
- err = &OpError{"write", fd.net, fd.raddr, oserr}
+ err = &OpError{"write", fd.net, fd.raddr, err}
}
return
}
-func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) {
+func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
if fd == nil || fd.sysfile == nil {
return 0, 0, os.EINVAL
}
@@ -553,73 +545,62 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
defer fd.wio.Unlock()
fd.incref()
defer fd.decref()
- if fd.wdeadline_delta > 0 {
- fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
- } else {
- fd.wdeadline = 0
- }
- var oserr os.Error
for {
- var errno int
- errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
- if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
- }
- if errno != 0 {
- oserr = os.Errno(errno)
+ err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
+ if err == syscall.EAGAIN {
+ if fd.wdeadline >= 0 {
+ pollserver.WaitWrite(fd)
+ continue
+ }
+ err = errTimeout
}
break
}
- if oserr == nil {
+ if err == nil {
n = len(p)
oobn = len(oob)
} else {
- err = &OpError{"write", fd.net, fd.raddr, oserr}
+ err = &OpError{"write", fd.net, fd.raddr, err}
}
return
}
-func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
+func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err error) {
if fd == nil || fd.sysfile == nil {
return nil, os.EINVAL
}
fd.incref()
defer fd.decref()
- if fd.rdeadline_delta > 0 {
- fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
- } else {
- fd.rdeadline = 0
- }
// See ../syscall/exec.go for description of ForkLock.
// It is okay to hold the lock across syscall.Accept
// because we have put fd.sysfd into non-blocking mode.
- syscall.ForkLock.RLock()
- var s, e int
+ var s int
var rsa syscall.Sockaddr
for {
if fd.closing {
- syscall.ForkLock.RUnlock()
return nil, os.EINVAL
}
- s, rsa, e = syscall.Accept(fd.sysfd)
- if e != syscall.EAGAIN || fd.rdeadline < 0 {
- break
- }
- syscall.ForkLock.RUnlock()
- pollserver.WaitRead(fd)
syscall.ForkLock.RLock()
- }
- if e != 0 {
- syscall.ForkLock.RUnlock()
- return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)}
+ s, rsa, err = syscall.Accept(fd.sysfd)
+ if err != nil {
+ syscall.ForkLock.RUnlock()
+ if err == syscall.EAGAIN {
+ if fd.rdeadline >= 0 {
+ pollserver.WaitRead(fd)
+ continue
+ }
+ err = errTimeout
+ }
+ return nil, &OpError{"accept", fd.net, fd.laddr, err}
+ }
+ break
}
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()
- if nfd, err = newFD(s, fd.family, fd.proto, fd.net); err != nil {
+ if nfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
syscall.Close(s)
return nil, err
}
@@ -628,20 +609,20 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.
return nfd, nil
}
-func (fd *netFD) dup() (f *os.File, err os.Error) {
- ns, e := syscall.Dup(fd.sysfd)
- if e != 0 {
- return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)}
+func (fd *netFD) dup() (f *os.File, err error) {
+ ns, err := syscall.Dup(fd.sysfd)
+ if err != nil {
+ return nil, &OpError{"dup", fd.net, fd.laddr, err}
}
// We want blocking mode for the new fd, hence the double negative.
- if e = syscall.SetNonblock(ns, false); e != 0 {
- return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)}
+ if err = syscall.SetNonblock(ns, false); err != nil {
+ return nil, &OpError{"setnonblock", fd.net, fd.laddr, err}
}
return os.NewFile(ns, fd.sysfile.Name()), nil
}
-func closesocket(s int) (errno int) {
+func closesocket(s int) error {
return syscall.Close(s)
}
diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go
index 7e3d549eb..c6db083c4 100644
--- a/src/pkg/net/fd_darwin.go
+++ b/src/pkg/net/fd_darwin.go
@@ -7,6 +7,7 @@
package net
import (
+ "errors"
"os"
"syscall"
)
@@ -21,17 +22,17 @@ type pollster struct {
kbuf [1]syscall.Kevent_t
}
-func newpollster() (p *pollster, err os.Error) {
+func newpollster() (p *pollster, err error) {
p = new(pollster)
- var e int
- if p.kq, e = syscall.Kqueue(); e != 0 {
- return nil, os.NewSyscallError("kqueue", e)
+ if p.kq, err = syscall.Kqueue(); err != nil {
+ return nil, os.NewSyscallError("kqueue", err)
}
+ syscall.CloseOnExec(p.kq)
p.events = p.eventbuf[0:0]
return p, nil
}
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
+func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
var kmode int
@@ -51,15 +52,15 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
}
syscall.SetKevent(ev, fd, kmode, flags)
- n, e := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil)
- if e != 0 {
- return false, os.NewSyscallError("kevent", e)
+ n, err := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil)
+ if err != nil {
+ return false, os.NewSyscallError("kevent", err)
}
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
- return false, os.NewError("kqueue phase error")
+ return false, errors.New("kqueue phase error")
}
if ev.Data != 0 {
- return false, os.Errno(int(ev.Data))
+ return false, syscall.Errno(ev.Data)
}
return false, nil
}
@@ -81,7 +82,7 @@ func (p *pollster) DelFD(fd int, mode int) {
syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil)
}
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) {
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
var t *syscall.Timespec
for len(p.events) == 0 {
if nsec > 0 {
@@ -95,11 +96,11 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[0:], t)
s.Lock()
- if e != 0 {
+ if e != nil {
if e == syscall.EINTR {
continue
}
- return -1, 0, os.NewSyscallError("kevent", e)
+ return -1, 0, os.NewSyscallError("kevent", nil)
}
if nn == 0 {
return -1, 0, nil
@@ -117,4 +118,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
return fd, mode, nil
}
-func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
+func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go
index e50883e94..31d0744e2 100644
--- a/src/pkg/net/fd_freebsd.go
+++ b/src/pkg/net/fd_freebsd.go
@@ -21,17 +21,17 @@ type pollster struct {
kbuf [1]syscall.Kevent_t
}
-func newpollster() (p *pollster, err os.Error) {
+func newpollster() (p *pollster, err error) {
p = new(pollster)
- var e int
- if p.kq, e = syscall.Kqueue(); e != 0 {
- return nil, os.NewSyscallError("kqueue", e)
+ if p.kq, err = syscall.Kqueue(); err != nil {
+ return nil, os.NewSyscallError("kqueue", err)
}
+ syscall.CloseOnExec(p.kq)
p.events = p.eventbuf[0:0]
return p, nil
}
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
+func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
var kmode int
@@ -50,14 +50,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
syscall.SetKevent(ev, fd, kmode, flags)
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
- if e != 0 {
+ if e != nil {
return false, os.NewSyscallError("kevent", e)
}
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
return false, os.NewSyscallError("kqueue phase error", e)
}
if ev.Data != 0 {
- return false, os.Errno(int(ev.Data))
+ return false, syscall.Errno(int(ev.Data))
}
return false, nil
}
@@ -77,7 +77,7 @@ func (p *pollster) DelFD(fd int, mode int) {
syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
}
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) {
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
var t *syscall.Timespec
for len(p.events) == 0 {
if nsec > 0 {
@@ -91,7 +91,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
s.Lock()
- if e != 0 {
+ if e != nil {
if e == syscall.EINTR {
continue
}
@@ -113,4 +113,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
return fd, mode, nil
}
-func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
+func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go
index 70fc344b2..c8df9c932 100644
--- a/src/pkg/net/fd_linux.go
+++ b/src/pkg/net/fd_linux.go
@@ -33,21 +33,27 @@ type pollster struct {
ctlEvent syscall.EpollEvent
}
-func newpollster() (p *pollster, err os.Error) {
+func newpollster() (p *pollster, err error) {
p = new(pollster)
- var e int
+ var e error
- // The arg to epoll_create is a hint to the kernel
- // about the number of FDs we will care about.
- // We don't know, and since 2.6.8 the kernel ignores it anyhow.
- if p.epfd, e = syscall.EpollCreate(16); e != 0 {
- return nil, os.NewSyscallError("epoll_create", e)
+ if p.epfd, e = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC); e != nil {
+ if e != syscall.ENOSYS {
+ return nil, os.NewSyscallError("epoll_create1", e)
+ }
+ // The arg to epoll_create is a hint to the kernel
+ // about the number of FDs we will care about.
+ // We don't know, and since 2.6.8 the kernel ignores it anyhow.
+ if p.epfd, e = syscall.EpollCreate(16); e != nil {
+ return nil, os.NewSyscallError("epoll_create", e)
+ }
+ syscall.CloseOnExec(p.epfd)
}
p.events = make(map[int]uint32)
return p, nil
}
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
+func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
var already bool
@@ -68,7 +74,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
} else {
op = syscall.EPOLL_CTL_ADD
}
- if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 {
+ if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != nil {
return false, os.NewSyscallError("epoll_ctl", e)
}
p.events[fd] = p.ctlEvent.Events
@@ -97,15 +103,15 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
if int32(events)&^syscall.EPOLLONESHOT != 0 {
p.ctlEvent.Fd = int32(fd)
p.ctlEvent.Events = events
- if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 {
- print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n")
+ if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != nil {
+ print("Epoll modify fd=", fd, ": ", e.Error(), "\n")
}
p.events[fd] = events
} else {
- if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 {
- print("Epoll delete fd=", fd, ": ", os.Errno(e).String(), "\n")
+ if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != nil {
+ print("Epoll delete fd=", fd, ": ", e.Error(), "\n")
}
- p.events[fd] = 0, false
+ delete(p.events, fd)
}
}
@@ -130,7 +136,7 @@ func (p *pollster) DelFD(fd int, mode int) {
}
}
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) {
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
for len(p.waitEvents) == 0 {
var msec int = -1
if nsec > 0 {
@@ -141,7 +147,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
s.Lock()
- if e != 0 {
+ if e != nil {
if e == syscall.EAGAIN || e == syscall.EINTR {
continue
}
@@ -177,6 +183,6 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
return fd, 'r', nil
}
-func (p *pollster) Close() os.Error {
+func (p *pollster) Close() error {
return os.NewSyscallError("close", syscall.Close(p.epfd))
}
diff --git a/src/pkg/net/fd_netbsd.go b/src/pkg/net/fd_netbsd.go
new file mode 100644
index 000000000..31d0744e2
--- /dev/null
+++ b/src/pkg/net/fd_netbsd.go
@@ -0,0 +1,116 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Waiting for FDs via kqueue/kevent.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+type pollster struct {
+ kq int
+ eventbuf [10]syscall.Kevent_t
+ events []syscall.Kevent_t
+
+ // An event buffer for AddFD/DelFD.
+ // Must hold pollServer lock.
+ kbuf [1]syscall.Kevent_t
+}
+
+func newpollster() (p *pollster, err error) {
+ p = new(pollster)
+ if p.kq, err = syscall.Kqueue(); err != nil {
+ return nil, os.NewSyscallError("kqueue", err)
+ }
+ syscall.CloseOnExec(p.kq)
+ p.events = p.eventbuf[0:0]
+ return p, nil
+}
+
+func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
+ // pollServer is locked.
+
+ var kmode int
+ if mode == 'r' {
+ kmode = syscall.EVFILT_READ
+ } else {
+ kmode = syscall.EVFILT_WRITE
+ }
+ ev := &p.kbuf[0]
+ // EV_ADD - add event to kqueue list
+ // EV_ONESHOT - delete the event the first time it triggers
+ flags := syscall.EV_ADD
+ if !repeat {
+ flags |= syscall.EV_ONESHOT
+ }
+ syscall.SetKevent(ev, fd, kmode, flags)
+
+ n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
+ if e != nil {
+ return false, os.NewSyscallError("kevent", e)
+ }
+ if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
+ return false, os.NewSyscallError("kqueue phase error", e)
+ }
+ if ev.Data != 0 {
+ return false, syscall.Errno(int(ev.Data))
+ }
+ return false, nil
+}
+
+func (p *pollster) DelFD(fd int, mode int) {
+ // pollServer is locked.
+
+ var kmode int
+ if mode == 'r' {
+ kmode = syscall.EVFILT_READ
+ } else {
+ kmode = syscall.EVFILT_WRITE
+ }
+ ev := &p.kbuf[0]
+ // EV_DELETE - delete event from kqueue list
+ syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE)
+ syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
+}
+
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
+ var t *syscall.Timespec
+ for len(p.events) == 0 {
+ if nsec > 0 {
+ if t == nil {
+ t = new(syscall.Timespec)
+ }
+ *t = syscall.NsecToTimespec(nsec)
+ }
+
+ s.Unlock()
+ nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
+ s.Lock()
+
+ if e != nil {
+ if e == syscall.EINTR {
+ continue
+ }
+ return -1, 0, os.NewSyscallError("kevent", e)
+ }
+ if nn == 0 {
+ return -1, 0, nil
+ }
+ p.events = p.eventbuf[0:nn]
+ }
+ ev := &p.events[0]
+ p.events = p.events[1:]
+ fd = int(ev.Ident)
+ if ev.Filter == syscall.EVFILT_READ {
+ mode = 'r'
+ } else {
+ mode = 'w'
+ }
+ return fd, mode, nil
+}
+
+func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_openbsd.go b/src/pkg/net/fd_openbsd.go
index e50883e94..31d0744e2 100644
--- a/src/pkg/net/fd_openbsd.go
+++ b/src/pkg/net/fd_openbsd.go
@@ -21,17 +21,17 @@ type pollster struct {
kbuf [1]syscall.Kevent_t
}
-func newpollster() (p *pollster, err os.Error) {
+func newpollster() (p *pollster, err error) {
p = new(pollster)
- var e int
- if p.kq, e = syscall.Kqueue(); e != 0 {
- return nil, os.NewSyscallError("kqueue", e)
+ if p.kq, err = syscall.Kqueue(); err != nil {
+ return nil, os.NewSyscallError("kqueue", err)
}
+ syscall.CloseOnExec(p.kq)
p.events = p.eventbuf[0:0]
return p, nil
}
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
+func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
var kmode int
@@ -50,14 +50,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) {
syscall.SetKevent(ev, fd, kmode, flags)
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
- if e != 0 {
+ if e != nil {
return false, os.NewSyscallError("kevent", e)
}
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
return false, os.NewSyscallError("kqueue phase error", e)
}
if ev.Data != 0 {
- return false, os.Errno(int(ev.Data))
+ return false, syscall.Errno(int(ev.Data))
}
return false, nil
}
@@ -77,7 +77,7 @@ func (p *pollster) DelFD(fd int, mode int) {
syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
}
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) {
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
var t *syscall.Timespec
for len(p.events) == 0 {
if nsec > 0 {
@@ -91,7 +91,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
s.Lock()
- if e != 0 {
+ if e != nil {
if e == syscall.EINTR {
continue
}
@@ -113,4 +113,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E
return fd, mode, nil
}
-func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
+func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go
index b025bddea..f00459f0b 100644
--- a/src/pkg/net/fd_windows.go
+++ b/src/pkg/net/fd_windows.go
@@ -5,6 +5,7 @@
package net
import (
+ "io"
"os"
"runtime"
"sync"
@@ -15,21 +16,21 @@ import (
type InvalidConnError struct{}
-func (e *InvalidConnError) String() string { return "invalid net.Conn" }
+func (e *InvalidConnError) Error() string { return "invalid net.Conn" }
func (e *InvalidConnError) Temporary() bool { return false }
func (e *InvalidConnError) Timeout() bool { return false }
-var initErr os.Error
+var initErr error
func init() {
var d syscall.WSAData
e := syscall.WSAStartup(uint32(0x202), &d)
- if e != 0 {
+ if e != nil {
initErr = os.NewSyscallError("WSAStartup", e)
}
}
-func closesocket(s syscall.Handle) (errno int) {
+func closesocket(s syscall.Handle) (err error) {
return syscall.Closesocket(s)
}
@@ -37,13 +38,13 @@ func closesocket(s syscall.Handle) (errno int) {
type anOpIface interface {
Op() *anOp
Name() string
- Submit() (errno int)
+ Submit() (err error)
}
// IO completion result parameters.
type ioResult struct {
qty uint32
- err int
+ err error
}
// anOp implements functionality common to all io operations.
@@ -53,7 +54,7 @@ type anOp struct {
o syscall.Overlapped
resultc chan ioResult
- errnoc chan int
+ errnoc chan error
fd *netFD
}
@@ -70,7 +71,7 @@ func (o *anOp) Init(fd *netFD, mode int) {
}
o.resultc = fd.resultc[i]
if fd.errnoc[i] == nil {
- fd.errnoc[i] = make(chan int)
+ fd.errnoc[i] = make(chan error)
}
o.errnoc = fd.errnoc[i]
}
@@ -110,14 +111,14 @@ func (s *resultSrv) Run() {
for {
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
switch {
- case r.err == 0:
+ case r.err == nil:
// Dequeued successfully completed io packet.
- case r.err == syscall.WAIT_TIMEOUT && o == nil:
+ case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
// Wait has timed out (should not happen now, but might be used in the future).
panic("GetQueuedCompletionStatus timed out")
case o == nil:
// Failed to dequeue anything -> report the error.
- panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err))
+ panic("GetQueuedCompletionStatus failed " + r.err.Error())
default:
// Dequeued failed io packet.
}
@@ -149,12 +150,13 @@ func (s *ioSrv) ProcessRemoteIO() {
}
// ExecIO executes a single io operation. It either executes it
-// inline, or, if timeouts are employed, passes the request onto
+// inline, or, if a deadline is employed, passes the request onto
// a special goroutine and waits for completion or cancels request.
-func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error) {
- var e int
+// deadline is unix nanos.
+func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (n int, err error) {
+ var e error
o := oi.Op()
- if deadline_delta > 0 {
+ if deadline != 0 {
// Send request to a special dedicated thread,
// so it can stop the io with CancelIO later.
s.submchan <- oi
@@ -163,19 +165,25 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error)
e = oi.Submit()
}
switch e {
- case 0:
+ case nil:
// IO completed immediately, but we need to get our completion message anyway.
case syscall.ERROR_IO_PENDING:
// IO started, and we have to wait for its completion.
default:
- return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)}
+ return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, e}
}
// Wait for our request to complete.
var r ioResult
- if deadline_delta > 0 {
+ if deadline != 0 {
+ dt := deadline - time.Now().UnixNano()
+ if dt < 1 {
+ dt = 1
+ }
+ timer := time.NewTimer(time.Duration(dt) * time.Nanosecond)
+ defer timer.Stop()
select {
case r = <-o.resultc:
- case <-time.After(deadline_delta):
+ case <-timer.C:
s.canchan <- oi
<-o.errnoc
r = <-o.resultc
@@ -186,8 +194,8 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error)
} else {
r = <-o.resultc
}
- if r.err != 0 {
- err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)}
+ if r.err != nil {
+ err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
}
return int(r.qty), err
}
@@ -199,10 +207,10 @@ var onceStartServer sync.Once
func startServer() {
resultsrv = new(resultSrv)
- var errno int
- resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
- if errno != 0 {
- panic("CreateIoCompletionPort failed " + syscall.Errstr(errno))
+ var err error
+ resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
+ if err != nil {
+ panic("CreateIoCompletionPort: " + err.Error())
}
go resultsrv.Run()
@@ -220,43 +228,42 @@ type netFD struct {
closing bool
// immutable until Close
- sysfd syscall.Handle
- family int
- proto int
- net string
- laddr Addr
- raddr Addr
- resultc [2]chan ioResult // read/write completion results
- errnoc [2]chan int // read/write submit or cancel operation errors
+ sysfd syscall.Handle
+ family int
+ sotype int
+ isConnected bool
+ net string
+ laddr Addr
+ raddr Addr
+ resultc [2]chan ioResult // read/write completion results
+ errnoc [2]chan error // read/write submit or cancel operation errors
// owned by client
- rdeadline_delta int64
- rdeadline int64
- rio sync.Mutex
- wdeadline_delta int64
- wdeadline int64
- wio sync.Mutex
+ rdeadline int64
+ rio sync.Mutex
+ wdeadline int64
+ wio sync.Mutex
}
-func allocFD(fd syscall.Handle, family, proto int, net string) (f *netFD) {
+func allocFD(fd syscall.Handle, family, sotype int, net string) (f *netFD) {
f = &netFD{
sysfd: fd,
family: family,
- proto: proto,
+ sotype: sotype,
net: net,
}
runtime.SetFinalizer(f, (*netFD).Close)
return f
}
-func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err os.Error) {
+func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err error) {
if initErr != nil {
return nil, initErr
}
onceStartServer.Do(startServer)
// Associate our socket with resultsrv.iocp.
- if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 {
- return nil, os.Errno(e)
+ if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != nil {
+ return nil, e
}
return allocFD(fd, family, proto, net), nil
}
@@ -266,12 +273,8 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
fd.raddr = raddr
}
-func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) {
- e := syscall.Connect(fd.sysfd, ra)
- if e != 0 {
- return os.Errno(e)
- }
- return nil
+func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
+ return syscall.Connect(fd.sysfd, ra)
}
// Add a reference to this fd.
@@ -300,7 +303,7 @@ func (fd *netFD) decref() {
fd.sysmu.Unlock()
}
-func (fd *netFD) Close() os.Error {
+func (fd *netFD) Close() error {
if fd == nil || fd.sysfd == syscall.InvalidHandle {
return os.EINVAL
}
@@ -312,13 +315,32 @@ func (fd *netFD) Close() os.Error {
return nil
}
+func (fd *netFD) shutdown(how int) error {
+ if fd == nil || fd.sysfd == syscall.InvalidHandle {
+ return os.EINVAL
+ }
+ err := syscall.Shutdown(fd.sysfd, how)
+ if err != nil {
+ return &OpError{"shutdown", fd.net, fd.laddr, err}
+ }
+ return nil
+}
+
+func (fd *netFD) CloseRead() error {
+ return fd.shutdown(syscall.SHUT_RD)
+}
+
+func (fd *netFD) CloseWrite() error {
+ return fd.shutdown(syscall.SHUT_WR)
+}
+
// Read from network.
type readOp struct {
bufOp
}
-func (o *readOp) Submit() (errno int) {
+func (o *readOp) Submit() (err error) {
var d, f uint32
return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil)
}
@@ -327,7 +349,7 @@ func (o *readOp) Name() string {
return "WSARecv"
}
-func (fd *netFD) Read(buf []byte) (n int, err os.Error) {
+func (fd *netFD) Read(buf []byte) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
}
@@ -340,9 +362,9 @@ func (fd *netFD) Read(buf []byte) (n int, err os.Error) {
}
var o readOp
o.Init(fd, buf, 'r')
- n, err = iosrv.ExecIO(&o, fd.rdeadline_delta)
+ n, err = iosrv.ExecIO(&o, fd.rdeadline)
if err == nil && n == 0 {
- err = os.EOF
+ err = io.EOF
}
return
}
@@ -355,7 +377,7 @@ type readFromOp struct {
rsan int32
}
-func (o *readFromOp) Submit() (errno int) {
+func (o *readFromOp) Submit() (err error) {
var d, f uint32
return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil)
}
@@ -364,7 +386,7 @@ func (o *readFromOp) Name() string {
return "WSARecvFrom"
}
-func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err os.Error) {
+func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) {
if fd == nil {
return 0, nil, os.EINVAL
}
@@ -381,7 +403,7 @@ func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err os.Error)
var o readFromOp
o.Init(fd, buf, 'r')
o.rsan = int32(unsafe.Sizeof(o.rsa))
- n, err = iosrv.ExecIO(&o, fd.rdeadline_delta)
+ n, err = iosrv.ExecIO(&o, fd.rdeadline)
if err != nil {
return 0, nil, err
}
@@ -395,7 +417,7 @@ type writeOp struct {
bufOp
}
-func (o *writeOp) Submit() (errno int) {
+func (o *writeOp) Submit() (err error) {
var d uint32
return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil)
}
@@ -404,7 +426,7 @@ func (o *writeOp) Name() string {
return "WSASend"
}
-func (fd *netFD) Write(buf []byte) (n int, err os.Error) {
+func (fd *netFD) Write(buf []byte) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
}
@@ -417,7 +439,7 @@ func (fd *netFD) Write(buf []byte) (n int, err os.Error) {
}
var o writeOp
o.Init(fd, buf, 'w')
- return iosrv.ExecIO(&o, fd.wdeadline_delta)
+ return iosrv.ExecIO(&o, fd.wdeadline)
}
// WriteTo to network.
@@ -427,7 +449,7 @@ type writeToOp struct {
sa syscall.Sockaddr
}
-func (o *writeToOp) Submit() (errno int) {
+func (o *writeToOp) Submit() (err error) {
var d uint32
return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil)
}
@@ -436,7 +458,7 @@ func (o *writeToOp) Name() string {
return "WSASendto"
}
-func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err os.Error) {
+func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
}
@@ -453,7 +475,7 @@ func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err os.Error)
var o writeToOp
o.Init(fd, buf, 'w')
o.sa = sa
- return iosrv.ExecIO(&o, fd.wdeadline_delta)
+ return iosrv.ExecIO(&o, fd.wdeadline)
}
// Accept new network connections.
@@ -464,7 +486,7 @@ type acceptOp struct {
attrs [2]syscall.RawSockaddrAny // space for local and remote address only
}
-func (o *acceptOp) Submit() (errno int) {
+func (o *acceptOp) Submit() (err error) {
var d uint32
l := uint32(unsafe.Sizeof(o.attrs[0]))
return syscall.AcceptEx(o.fd.sysfd, o.newsock,
@@ -475,7 +497,7 @@ func (o *acceptOp) Name() string {
return "AcceptEx"
}
-func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
+func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err error) {
if fd == nil || fd.sysfd == syscall.InvalidHandle {
return nil, os.EINVAL
}
@@ -485,18 +507,18 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.
// Get new socket.
// See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock()
- s, e := syscall.Socket(fd.family, fd.proto, 0)
- if e != 0 {
+ s, e := syscall.Socket(fd.family, fd.sotype, 0)
+ if e != nil {
syscall.ForkLock.RUnlock()
- return nil, os.Errno(e)
+ return nil, e
}
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()
// Associate our new socket with IOCP.
onceStartServer.Do(startServer)
- if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 {
- return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)}
+ if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != nil {
+ return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, e}
}
// Submit accept request.
@@ -511,9 +533,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.
// Inherit properties of the listening socket.
e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
- if e != 0 {
+ if e != nil {
closesocket(s)
- return nil, err
+ return nil, e
}
// Get local and peer addr out of AcceptEx buffer.
@@ -525,22 +547,22 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.
lsa, _ := lrsa.Sockaddr()
rsa, _ := rrsa.Sockaddr()
- nfd = allocFD(s, fd.family, fd.proto, fd.net)
+ nfd = allocFD(s, fd.family, fd.sotype, fd.net)
nfd.setAddr(toAddr(lsa), toAddr(rsa))
return nfd, nil
}
// Unimplemented functions.
-func (fd *netFD) dup() (f *os.File, err os.Error) {
+func (fd *netFD) dup() (f *os.File, err error) {
// TODO: Implement this
return nil, os.NewSyscallError("dup", syscall.EWINDOWS)
}
-func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) {
+func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
return 0, 0, 0, nil, os.EAFNOSUPPORT
}
-func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) {
+func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
return 0, 0, os.EAFNOSUPPORT
}
diff --git a/src/pkg/net/file.go b/src/pkg/net/file.go
index d8528e41b..4ac280bd1 100644
--- a/src/pkg/net/file.go
+++ b/src/pkg/net/file.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
package net
@@ -11,17 +11,18 @@ import (
"syscall"
)
-func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
+func newFileFD(f *os.File) (nfd *netFD, err error) {
fd, errno := syscall.Dup(f.Fd())
- if errno != 0 {
+ if errno != nil {
return nil, os.NewSyscallError("dup", errno)
}
proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
- if errno != 0 {
+ if errno != nil {
return nil, os.NewSyscallError("getsockopt", errno)
}
+ family := syscall.AF_UNSPEC
toAddr := sockaddrToTCP
sa, _ := syscall.Getsockname(fd)
switch sa.(type) {
@@ -29,18 +30,21 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
closesocket(fd)
return nil, os.EINVAL
case *syscall.SockaddrInet4:
+ family = syscall.AF_INET
if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP
} else if proto == syscall.SOCK_RAW {
toAddr = sockaddrToIP
}
case *syscall.SockaddrInet6:
+ family = syscall.AF_INET6
if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP
} else if proto == syscall.SOCK_RAW {
toAddr = sockaddrToIP
}
case *syscall.SockaddrUnix:
+ family = syscall.AF_UNIX
toAddr = sockaddrToUnix
if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUnixgram
@@ -52,7 +56,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
sa, _ = syscall.Getpeername(fd)
raddr := toAddr(sa)
- if nfd, err = newFD(fd, 0, proto, laddr.Network()); err != nil {
+ if nfd, err = newFD(fd, family, proto, laddr.Network()); err != nil {
return nil, err
}
nfd.setAddr(laddr, raddr)
@@ -63,7 +67,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
// the open file f. It is the caller's responsibility to close f when
// finished. Closing c does not affect f, and closing f does not
// affect c.
-func FileConn(f *os.File) (c Conn, err os.Error) {
+func FileConn(f *os.File) (c Conn, err error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
@@ -86,7 +90,7 @@ func FileConn(f *os.File) (c Conn, err os.Error) {
// to the open file f. It is the caller's responsibility to close l
// when finished. Closing c does not affect l, and closing l does not
// affect c.
-func FileListener(f *os.File) (l Listener, err os.Error) {
+func FileListener(f *os.File) (l Listener, err error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
@@ -105,7 +109,7 @@ func FileListener(f *os.File) (l Listener, err os.Error) {
// corresponding to the open file f. It is the caller's
// responsibility to close f when finished. Closing c does not affect
// f, and closing f does not affect c.
-func FilePacketConn(f *os.File) (c PacketConn, err os.Error) {
+func FilePacketConn(f *os.File) (c PacketConn, err error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go
index a07e74331..06d7cc898 100644
--- a/src/pkg/net/file_plan9.go
+++ b/src/pkg/net/file_plan9.go
@@ -12,7 +12,7 @@ import (
// the open file f. It is the caller's responsibility to close f when
// finished. Closing c does not affect f, and closing f does not
// affect c.
-func FileConn(f *os.File) (c Conn, err os.Error) {
+func FileConn(f *os.File) (c Conn, err error) {
return nil, os.EPLAN9
}
@@ -20,7 +20,7 @@ func FileConn(f *os.File) (c Conn, err os.Error) {
// to the open file f. It is the caller's responsibility to close l
// when finished. Closing c does not affect l, and closing l does not
// affect c.
-func FileListener(f *os.File) (l Listener, err os.Error) {
+func FileListener(f *os.File) (l Listener, err error) {
return nil, os.EPLAN9
}
@@ -28,6 +28,6 @@ func FileListener(f *os.File) (l Listener, err os.Error) {
// corresponding to the open file f. It is the caller's
// responsibility to close f when finished. Closing c does not affect
// f, and closing f does not affect c.
-func FilePacketConn(f *os.File) (c PacketConn, err os.Error) {
+func FilePacketConn(f *os.File) (c PacketConn, err error) {
return nil, os.EPLAN9
}
diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go
index 9a8c2dcbc..868388efa 100644
--- a/src/pkg/net/file_test.go
+++ b/src/pkg/net/file_test.go
@@ -8,23 +8,22 @@ import (
"os"
"reflect"
"runtime"
- "syscall"
"testing"
)
type listenerFile interface {
Listener
- File() (f *os.File, err os.Error)
+ File() (f *os.File, err error)
}
type packetConnFile interface {
PacketConn
- File() (f *os.File, err os.Error)
+ File() (f *os.File, err error)
}
type connFile interface {
Conn
- File() (f *os.File, err os.Error)
+ File() (f *os.File, err error)
}
func testFileListener(t *testing.T, net, laddr string) {
@@ -67,13 +66,13 @@ func TestFileListener(t *testing.T) {
testFileListener(t, "tcp", "127.0.0.1")
testFileListener(t, "tcp", "[::ffff:127.0.0.1]")
}
- if syscall.OS == "linux" {
+ if runtime.GOOS == "linux" {
testFileListener(t, "unix", "@gotest/net")
testFileListener(t, "unixpacket", "@gotest/net")
}
}
-func testFilePacketConn(t *testing.T, pcf packetConnFile) {
+func testFilePacketConn(t *testing.T, pcf packetConnFile, listen bool) {
f, err := pcf.File()
if err != nil {
t.Fatalf("File failed: %v", err)
@@ -85,6 +84,11 @@ func testFilePacketConn(t *testing.T, pcf packetConnFile) {
if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) {
t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr())
}
+ if listen {
+ if _, err := c.WriteTo([]byte{}, c.LocalAddr()); err != nil {
+ t.Fatalf("WriteTo failed: %v", err)
+ }
+ }
if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
@@ -98,7 +102,7 @@ func testFilePacketConnListen(t *testing.T, net, laddr string) {
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
- testFilePacketConn(t, l.(packetConnFile))
+ testFilePacketConn(t, l.(packetConnFile), true)
if err := l.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
@@ -109,7 +113,7 @@ func testFilePacketConnDial(t *testing.T, net, raddr string) {
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
- testFilePacketConn(t, c.(packetConnFile))
+ testFilePacketConn(t, c.(packetConnFile), false)
if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
@@ -127,7 +131,7 @@ func TestFilePacketConn(t *testing.T) {
if supportsIPv6 && supportsIPv4map {
testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345")
}
- if syscall.OS == "linux" {
+ if runtime.GOOS == "linux" {
testFilePacketConnListen(t, "unixgram", "@gotest1/net")
}
}
diff --git a/src/pkg/net/file_windows.go b/src/pkg/net/file_windows.go
index 94aa58375..c50c32e21 100644
--- a/src/pkg/net/file_windows.go
+++ b/src/pkg/net/file_windows.go
@@ -9,17 +9,17 @@ import (
"syscall"
)
-func FileConn(f *os.File) (c Conn, err os.Error) {
+func FileConn(f *os.File) (c Conn, err error) {
// TODO: Implement this
return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS)
}
-func FileListener(f *os.File) (l Listener, err os.Error) {
+func FileListener(f *os.File) (l Listener, err error) {
// TODO: Implement this
return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS)
}
-func FilePacketConn(f *os.File) (c PacketConn, err os.Error) {
+func FilePacketConn(f *os.File) (c PacketConn, err error) {
// TODO: Implement this
return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS)
}
diff --git a/src/pkg/net/hosts.go b/src/pkg/net/hosts.go
index d75e9e038..e6674ba34 100644
--- a/src/pkg/net/hosts.go
+++ b/src/pkg/net/hosts.go
@@ -7,11 +7,11 @@
package net
import (
- "os"
"sync"
+ "time"
)
-const cacheMaxAge = int64(300) // 5 minutes.
+const cacheMaxAge = 5 * time.Minute
// hostsPath points to the file with static IP/address entries.
var hostsPath = "/etc/hosts"
@@ -21,14 +21,14 @@ var hosts struct {
sync.Mutex
byName map[string][]string
byAddr map[string][]string
- time int64
+ expire time.Time
path string
}
func readHosts() {
- now, _, _ := os.Time()
+ now := time.Now()
hp := hostsPath
- if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp {
+ if len(hosts.byName) == 0 || now.After(hosts.expire) || hosts.path != hp {
hs := make(map[string][]string)
is := make(map[string][]string)
var file *file
@@ -51,7 +51,7 @@ func readHosts() {
}
}
// Update the data cache.
- hosts.time, _, _ = os.Time()
+ hosts.expire = time.Now().Add(cacheMaxAge)
hosts.path = hp
hosts.byName = hs
hosts.byAddr = is
diff --git a/src/pkg/net/http/Makefile b/src/pkg/net/http/Makefile
new file mode 100644
index 000000000..5c351b0c4
--- /dev/null
+++ b/src/pkg/net/http/Makefile
@@ -0,0 +1,25 @@
+# Copyright 2009 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=net/http
+GOFILES=\
+ chunked.go\
+ client.go\
+ cookie.go\
+ filetransport.go\
+ fs.go\
+ header.go\
+ jar.go\
+ lex.go\
+ request.go\
+ response.go\
+ server.go\
+ sniff.go\
+ status.go\
+ transfer.go\
+ transport.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/net/http/cgi/Makefile b/src/pkg/net/http/cgi/Makefile
new file mode 100644
index 000000000..0d6be0180
--- /dev/null
+++ b/src/pkg/net/http/cgi/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/http/cgi
+GOFILES=\
+ child.go\
+ host.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go
new file mode 100644
index 000000000..e6c3ef911
--- /dev/null
+++ b/src/pkg/net/http/cgi/child.go
@@ -0,0 +1,192 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements CGI from the perspective of a child
+// process.
+
+package cgi
+
+import (
+ "bufio"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+)
+
+// Request returns the HTTP request as represented in the current
+// environment. This assumes the current program is being run
+// by a web server in a CGI environment.
+// The returned Request's Body is populated, if applicable.
+func Request() (*http.Request, error) {
+ r, err := RequestFromMap(envMap(os.Environ()))
+ if err != nil {
+ return nil, err
+ }
+ if r.ContentLength > 0 {
+ r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength))
+ }
+ return r, nil
+}
+
+func envMap(env []string) map[string]string {
+ m := make(map[string]string)
+ for _, kv := range env {
+ if idx := strings.Index(kv, "="); idx != -1 {
+ m[kv[:idx]] = kv[idx+1:]
+ }
+ }
+ return m
+}
+
+// RequestFromMap creates an http.Request from CGI variables.
+// The returned Request's Body field is not populated.
+func RequestFromMap(params map[string]string) (*http.Request, error) {
+ r := new(http.Request)
+ r.Method = params["REQUEST_METHOD"]
+ if r.Method == "" {
+ return nil, errors.New("cgi: no REQUEST_METHOD in environment")
+ }
+
+ r.Proto = params["SERVER_PROTOCOL"]
+ var ok bool
+ r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto)
+ if !ok {
+ return nil, errors.New("cgi: invalid SERVER_PROTOCOL version")
+ }
+
+ r.Close = true
+ r.Trailer = http.Header{}
+ r.Header = http.Header{}
+
+ r.Host = params["HTTP_HOST"]
+
+ if lenstr := params["CONTENT_LENGTH"]; lenstr != "" {
+ clen, err := strconv.ParseInt(lenstr, 10, 64)
+ if err != nil {
+ return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
+ }
+ r.ContentLength = clen
+ }
+
+ if ct := params["CONTENT_TYPE"]; ct != "" {
+ r.Header.Set("Content-Type", ct)
+ }
+
+ // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
+ for k, v := range params {
+ if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" {
+ continue
+ }
+ r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v)
+ }
+
+ // TODO: cookies. parsing them isn't exported, though.
+
+ if r.Host != "" {
+ // Hostname is provided, so we can reasonably construct a URL,
+ // even if we have to assume 'http' for the scheme.
+ rawurl := "http://" + r.Host + params["REQUEST_URI"]
+ url, err := url.Parse(rawurl)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
+ }
+ r.URL = url
+ }
+ // Fallback logic if we don't have a Host header or the URL
+ // failed to parse
+ if r.URL == nil {
+ uriStr := params["REQUEST_URI"]
+ url, err := url.Parse(uriStr)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
+ }
+ r.URL = url
+ }
+
+ // There's apparently a de-facto standard for this.
+ // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+ if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
+ r.TLS = &tls.ConnectionState{HandshakeComplete: true}
+ }
+
+ // Request.RemoteAddr has its port set by Go's standard http
+ // server, so we do here too. We don't have one, though, so we
+ // use a dummy one.
+ r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0")
+
+ return r, nil
+}
+
+// Serve executes the provided Handler on the currently active CGI
+// request, if any. If there's no current CGI environment
+// an error is returned. The provided handler may be nil to use
+// http.DefaultServeMux.
+func Serve(handler http.Handler) error {
+ req, err := Request()
+ if err != nil {
+ return err
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ rw := &response{
+ req: req,
+ header: make(http.Header),
+ bufw: bufio.NewWriter(os.Stdout),
+ }
+ handler.ServeHTTP(rw, req)
+ if err = rw.bufw.Flush(); err != nil {
+ return err
+ }
+ return nil
+}
+
+type response struct {
+ req *http.Request
+ header http.Header
+ bufw *bufio.Writer
+ headerSent bool
+}
+
+func (r *response) Flush() {
+ r.bufw.Flush()
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(p []byte) (n int, err error) {
+ if !r.headerSent {
+ r.WriteHeader(http.StatusOK)
+ }
+ return r.bufw.Write(p)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.headerSent {
+ // Note: explicitly using Stderr, as Stdout is our HTTP output.
+ fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
+ return
+ }
+ r.headerSent = true
+ fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code))
+
+ // Set a default Content-Type
+ if _, hasType := r.header["Content-Type"]; !hasType {
+ r.header.Add("Content-Type", "text/html; charset=utf-8")
+ }
+
+ r.header.Write(r.bufw)
+ r.bufw.WriteString("\r\n")
+ r.bufw.Flush()
+}
diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go
new file mode 100644
index 000000000..ec53ab851
--- /dev/null
+++ b/src/pkg/net/http/cgi/child_test.go
@@ -0,0 +1,87 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for CGI (the child process perspective)
+
+package cgi
+
+import (
+ "testing"
+)
+
+func TestRequest(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.com",
+ "HTTP_REFERER": "elsewhere",
+ "HTTP_USER_AGENT": "goclient",
+ "HTTP_FOO_BAR": "baz",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ "CONTENT_TYPE": "text/xml",
+ "HTTPS": "1",
+ "REMOTE_ADDR": "5.6.7.8",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if g, e := req.UserAgent(), "goclient"; e != g {
+ t.Errorf("expected UserAgent %q; got %q", e, g)
+ }
+ if g, e := req.Method, "GET"; e != g {
+ t.Errorf("expected Method %q; got %q", e, g)
+ }
+ if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g {
+ t.Errorf("expected Content-Type %q; got %q", e, g)
+ }
+ if g, e := req.ContentLength, int64(123); e != g {
+ t.Errorf("expected ContentLength %d; got %d", e, g)
+ }
+ if g, e := req.Referer(), "elsewhere"; e != g {
+ t.Errorf("expected Referer %q; got %q", e, g)
+ }
+ if req.Header == nil {
+ t.Fatalf("unexpected nil Header")
+ }
+ if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g {
+ t.Errorf("expected Foo-Bar %q; got %q", e, g)
+ }
+ if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+ if g, e := req.FormValue("a"), "b"; e != g {
+ t.Errorf("expected FormValue(a) %q; got %q", e, g)
+ }
+ if req.Trailer == nil {
+ t.Errorf("unexpected nil Trailer")
+ }
+ if req.TLS == nil {
+ t.Errorf("expected non-nil TLS")
+ }
+ if e, g := "5.6.7.8:0", req.RemoteAddr; e != g {
+ t.Errorf("RemoteAddr: got %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutHost(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "",
+ "REQUEST_METHOD": "GET",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+}
diff --git a/src/pkg/net/http/cgi/host.go b/src/pkg/net/http/cgi/host.go
new file mode 100644
index 000000000..73a9b6ea6
--- /dev/null
+++ b/src/pkg/net/http/cgi/host.go
@@ -0,0 +1,350 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements the host side of CGI (being the webserver
+// parent process).
+
+// Package cgi implements CGI (Common Gateway Interface) as specified
+// in RFC 3875.
+//
+// Note that using CGI means starting a new process to handle each
+// request, which is typically less efficient than using a
+// long-running server. This package is intended primarily for
+// compatibility with existing systems.
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+)
+
+var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
+
+var osDefaultInheritEnv = map[string][]string{
+ "darwin": {"DYLD_LIBRARY_PATH"},
+ "freebsd": {"LD_LIBRARY_PATH"},
+ "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"},
+ "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"},
+ "linux": {"LD_LIBRARY_PATH"},
+ "openbsd": {"LD_LIBRARY_PATH"},
+ "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"},
+ "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"},
+}
+
+// Handler runs an executable in a subprocess with a CGI environment.
+type Handler struct {
+ Path string // path to the CGI executable
+ Root string // root URI prefix of handler or empty for "/"
+
+ // Dir specifies the CGI executable's working directory.
+ // If Dir is empty, the base directory of Path is used.
+ // If Path has no base directory, the current working
+ // directory is used.
+ Dir string
+
+ Env []string // extra environment variables to set, if any, as "key=value"
+ InheritEnv []string // environment variables to inherit from host, as "key"
+ Logger *log.Logger // optional log for errors or nil to use log.Print
+ Args []string // optional arguments to pass to child process
+
+ // PathLocationHandler specifies the root http Handler that
+ // should handle internal redirects when the CGI process
+ // returns a Location header value starting with a "/", as
+ // specified in RFC 3875 § 6.3.2. This will likely be
+ // http.DefaultServeMux.
+ //
+ // If nil, a CGI response with a local URI path is instead sent
+ // back to the client and not redirected internally.
+ PathLocationHandler http.Handler
+}
+
+// removeLeadingDuplicates remove leading duplicate in environments.
+// It's possible to override environment like following.
+// cgi.Handler{
+// ...
+// Env: []string{"SCRIPT_FILENAME=foo.php"},
+// }
+func removeLeadingDuplicates(env []string) (ret []string) {
+ n := len(env)
+ for i := 0; i < n; i++ {
+ e := env[i]
+ s := strings.SplitN(e, "=", 2)[0]
+ found := false
+ for j := i + 1; j < n; j++ {
+ if s == strings.SplitN(env[j], "=", 2)[0] {
+ found = true
+ break
+ }
+ }
+ if !found {
+ ret = append(ret, e)
+ }
+ }
+ return
+}
+
+func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ root := h.Root
+ if root == "" {
+ root = "/"
+ }
+
+ if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
+ rw.WriteHeader(http.StatusBadRequest)
+ rw.Write([]byte("Chunked request bodies are not supported by CGI."))
+ return
+ }
+
+ pathInfo := req.URL.Path
+ if root != "/" && strings.HasPrefix(pathInfo, root) {
+ pathInfo = pathInfo[len(root):]
+ }
+
+ port := "80"
+ if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
+ port = matches[1]
+ }
+
+ env := []string{
+ "SERVER_SOFTWARE=go",
+ "SERVER_NAME=" + req.Host,
+ "SERVER_PROTOCOL=HTTP/1.1",
+ "HTTP_HOST=" + req.Host,
+ "GATEWAY_INTERFACE=CGI/1.1",
+ "REQUEST_METHOD=" + req.Method,
+ "QUERY_STRING=" + req.URL.RawQuery,
+ "REQUEST_URI=" + req.URL.RequestURI(),
+ "PATH_INFO=" + pathInfo,
+ "SCRIPT_NAME=" + root,
+ "SCRIPT_FILENAME=" + h.Path,
+ "REMOTE_ADDR=" + req.RemoteAddr,
+ "REMOTE_HOST=" + req.RemoteAddr,
+ "SERVER_PORT=" + port,
+ }
+
+ if req.TLS != nil {
+ env = append(env, "HTTPS=on")
+ }
+
+ for k, v := range req.Header {
+ k = strings.Map(upperCaseAndUnderscore, k)
+ joinStr := ", "
+ if k == "COOKIE" {
+ joinStr = "; "
+ }
+ env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr))
+ }
+
+ if req.ContentLength > 0 {
+ env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
+ }
+ if ctype := req.Header.Get("Content-Type"); ctype != "" {
+ env = append(env, "CONTENT_TYPE="+ctype)
+ }
+
+ if h.Env != nil {
+ env = append(env, h.Env...)
+ }
+
+ envPath := os.Getenv("PATH")
+ if envPath == "" {
+ envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
+ }
+ env = append(env, "PATH="+envPath)
+
+ for _, e := range h.InheritEnv {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ for _, e := range osDefaultInheritEnv[runtime.GOOS] {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ env = removeLeadingDuplicates(env)
+
+ var cwd, path string
+ if h.Dir != "" {
+ path = h.Path
+ cwd = h.Dir
+ } else {
+ cwd, path = filepath.Split(h.Path)
+ }
+ if cwd == "" {
+ cwd = "."
+ }
+
+ internalError := func(err error) {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("CGI error: %v", err)
+ }
+
+ cmd := &exec.Cmd{
+ Path: path,
+ Args: append([]string{h.Path}, h.Args...),
+ Dir: cwd,
+ Env: env,
+ Stderr: os.Stderr, // for now
+ }
+ if req.ContentLength != 0 {
+ cmd.Stdin = req.Body
+ }
+ stdoutRead, err := cmd.StdoutPipe()
+ if err != nil {
+ internalError(err)
+ return
+ }
+
+ err = cmd.Start()
+ if err != nil {
+ internalError(err)
+ return
+ }
+ defer cmd.Wait()
+ defer stdoutRead.Close()
+
+ linebody, _ := bufio.NewReaderSize(stdoutRead, 1024)
+ headers := make(http.Header)
+ statusCode := 0
+ for {
+ line, isPrefix, err := linebody.ReadLine()
+ if isPrefix {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: long header line from subprocess.")
+ return
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error reading headers: %v", err)
+ return
+ }
+ if len(line) == 0 {
+ break
+ }
+ parts := strings.SplitN(string(line), ":", 2)
+ if len(parts) < 2 {
+ h.printf("cgi: bogus header line: %s", string(line))
+ continue
+ }
+ header, val := parts[0], parts[1]
+ header = strings.TrimSpace(header)
+ val = strings.TrimSpace(val)
+ switch {
+ case header == "Status":
+ if len(val) < 3 {
+ h.printf("cgi: bogus status (short): %q", val)
+ return
+ }
+ code, err := strconv.Atoi(val[0:3])
+ if err != nil {
+ h.printf("cgi: bogus status: %q", val)
+ h.printf("cgi: line was %q", line)
+ return
+ }
+ statusCode = code
+ default:
+ headers.Add(header, val)
+ }
+ }
+
+ if loc := headers.Get("Location"); loc != "" {
+ if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
+ h.handleInternalRedirect(rw, req, loc)
+ return
+ }
+ if statusCode == 0 {
+ statusCode = http.StatusFound
+ }
+ }
+
+ if statusCode == 0 {
+ statusCode = http.StatusOK
+ }
+
+ // Copy headers to rw's headers, after we've decided not to
+ // go into handleInternalRedirect, which won't want its rw
+ // headers to have been touched.
+ for k, vv := range headers {
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+
+ rw.WriteHeader(statusCode)
+
+ _, err = io.Copy(rw, linebody)
+ if err != nil {
+ h.printf("cgi: copy error: %v", err)
+ }
+}
+
+func (h *Handler) printf(format string, v ...interface{}) {
+ if h.Logger != nil {
+ h.Logger.Printf(format, v...)
+ } else {
+ log.Printf(format, v...)
+ }
+}
+
+func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
+ url, err := req.URL.Parse(path)
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error resolving local URI path %q: %v", path, err)
+ return
+ }
+ // TODO: RFC 3875 isn't clear if only GET is supported, but it
+ // suggests so: "Note that any message-body attached to the
+ // request (such as for a POST request) may not be available
+ // to the resource that is the target of the redirect." We
+ // should do some tests against Apache to see how it handles
+ // POST, HEAD, etc. Does the internal redirect get the same
+ // method or just GET? What about incoming headers?
+ // (e.g. Cookies) Which headers, if any, are copied into the
+ // second request?
+ newReq := &http.Request{
+ Method: "GET",
+ URL: url,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: url.Host,
+ RemoteAddr: req.RemoteAddr,
+ TLS: req.TLS,
+ }
+ h.PathLocationHandler.ServeHTTP(rw, newReq)
+}
+
+func upperCaseAndUnderscore(r rune) rune {
+ switch {
+ case r >= 'a' && r <= 'z':
+ return r - ('a' - 'A')
+ case r == '-':
+ return '_'
+ case r == '=':
+ // Maybe not part of the CGI 'spec' but would mess up
+ // the environment in any case, as Go represents the
+ // environment as a slice of "key=value" strings.
+ return '_'
+ }
+ // TODO: other transformations in spec or practice?
+ return r
+}
diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go
new file mode 100644
index 000000000..9ef80ea5e
--- /dev/null
+++ b/src/pkg/net/http/cgi/host_test.go
@@ -0,0 +1,477 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for package cgi
+
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func newRequest(httpreq string) *http.Request {
+ buf := bufio.NewReader(strings.NewReader(httpreq))
+ req, err := http.ReadRequest(buf)
+ if err != nil {
+ panic("cgi: bogus http request in test: " + httpreq)
+ }
+ req.RemoteAddr = "1.2.3.4"
+ return req
+}
+
+func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder {
+ rw := httptest.NewRecorder()
+ req := newRequest(httpreq)
+ h.ServeHTTP(rw, req)
+
+ // Make a map to hold the test map that the CGI returns.
+ m := make(map[string]string)
+ linesRead := 0
+readlines:
+ for {
+ line, err := rw.Body.ReadString('\n')
+ switch {
+ case err == io.EOF:
+ break readlines
+ case err != nil:
+ t.Fatalf("unexpected error reading from CGI: %v", err)
+ }
+ linesRead++
+ trimmedLine := strings.TrimRight(line, "\r\n")
+ split := strings.SplitN(trimmedLine, "=", 2)
+ if len(split) != 2 {
+ t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v",
+ len(split), linesRead, line, m)
+ }
+ m[split[0]] = split[1]
+ }
+
+ for key, expected := range expectedMap {
+ if got := m[key]; got != expected {
+ t.Errorf("for key %q got %q; expected %q", key, got, expected)
+ }
+ }
+ return rw
+}
+
+var cgiTested = false
+var cgiWorks bool
+
+func skipTest(t *testing.T) bool {
+ if !cgiTested {
+ cgiTested = true
+ cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
+ }
+ if !cgiWorks {
+ // No Perl on Windows, needed by test.cgi
+ // TODO: make the child process be Go, not Perl.
+ t.Logf("Skipping test: test.cgi failed.")
+ return true
+ }
+ return false
+}
+
+func TestCGIBasicGet(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+func TestCGIBasicGetAbsPath(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ pwd, err := os.Getwd()
+ if err != nil {
+ t.Fatalf("getwd error: %v", err)
+ }
+ h := &Handler{
+ Path: pwd + "/testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfo(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "param-a": "b",
+ "env-PATH_INFO": "/extrapath",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/test.cgi/extrapath?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfoDirRoot(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/myscript/",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/myscript/",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDupHeaders(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-HTTP_COOKIE": "nom=NOM; yum=YUM",
+ "env-HTTP_X_FOO": "val1, val2",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
+ "Cookie: nom=NOM\n"+
+ "Cookie: yum=YUM\n"+
+ "X-Foo: val1\n"+
+ "X-Foo: val2\n"+
+ "Host: example.com\n\n",
+ expectedMap)
+}
+
+func TestPathInfoNoRoot(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "/bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/",
+ }
+ runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestCGIBasicPost(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ postReq := `POST /test.cgi?a=b HTTP/1.0
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Content-Length: 15
+
+postfoo=postbar`
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-postfoo": "postbar",
+ "env-REQUEST_METHOD": "POST",
+ "env-CONTENT_LENGTH": "15",
+ "env-REQUEST_URI": "/test.cgi?a=b",
+ }
+ runCgiTest(t, h, postReq, expectedMap)
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+// The CGI spec doesn't allow chunked requests.
+func TestCGIPostChunked(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ postReq := `POST /test.cgi?a=b HTTP/1.1
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Transfer-Encoding: chunked
+
+` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
+
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{}
+ resp := runCgiTest(t, h, postReq, expectedMap)
+ if got, expected := resp.Code, http.StatusBadRequest; got != expected {
+ t.Fatalf("Expected %v response code from chunked request body; got %d",
+ expected, got)
+ }
+}
+
+func TestRedirect(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
+ if e, g := 302, rec.Code; e != g {
+ t.Errorf("expected status code %d; got %d", e, g)
+ }
+ if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g {
+ t.Errorf("expected Location header of %q; got %q", e, g)
+ }
+}
+
+func TestInternalRedirect(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
+ fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
+ })
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ PathLocationHandler: baseHandler,
+ }
+ expectedMap := map[string]string{
+ "basepath": "/foo",
+ "remoteaddr": "1.2.3.4",
+ }
+ runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+// TestCopyError tests that we kill the process if there's an error copying
+// its output. (for example, from the client having gone away)
+func TestCopyError(t *testing.T) {
+ if skipTest(t) || runtime.GOOS == "windows" {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ ts := httptest.NewServer(h)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil)
+ err = req.Write(conn)
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+
+ res, err := http.ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("ReadResponse: %v", err)
+ }
+
+ pidstr := res.Header.Get("X-CGI-Pid")
+ if pidstr == "" {
+ t.Fatalf("expected an X-CGI-Pid header in response")
+ }
+ pid, err := strconv.Atoi(pidstr)
+ if err != nil {
+ t.Fatalf("invalid X-CGI-Pid value")
+ }
+
+ var buf [5000]byte
+ n, err := io.ReadFull(res.Body, buf[:])
+ if err != nil {
+ t.Fatalf("ReadFull: %d bytes, %v", n, err)
+ }
+
+ childRunning := func() bool {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return false
+ }
+ return p.Signal(os.UnixSignal(0)) == nil
+ }
+
+ if !childRunning() {
+ t.Fatalf("pre-conn.Close, expected child to be running")
+ }
+ conn.Close()
+
+ tries := 0
+ for tries < 25 && childRunning() {
+ time.Sleep(50 * time.Millisecond * time.Duration(tries))
+ tries++
+ }
+ if childRunning() {
+ t.Fatalf("post-conn.Close, expected child to be gone")
+ }
+}
+
+func TestDirUnix(t *testing.T) {
+ if skipTest(t) || runtime.GOOS == "windows" {
+ return
+ }
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ Dir: cwd,
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ cwd, _ = os.Getwd()
+ cwd = filepath.Join(cwd, "testdata")
+ h = &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDirWindows(t *testing.T) {
+ if skipTest(t) || runtime.GOOS != "windows" {
+ return
+ }
+
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ var perl string
+ var err error
+ perl, err = exec.LookPath("perl")
+ if err != nil {
+ return
+ }
+ perl, _ = filepath.Abs(perl)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ // If not specify Dir on windows, working directory should be
+ // base directory of perl.
+ cwd, _ = filepath.Split(perl)
+ if cwd != "" && cwd[len(cwd)-1] == filepath.Separator {
+ cwd = cwd[:len(cwd)-1]
+ }
+ h = &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestEnvOverride(t *testing.T) {
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ var perl string
+ var err error
+ perl, err = exec.LookPath("perl")
+ if err != nil {
+ return
+ }
+ perl, _ = filepath.Abs(perl)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{
+ "SCRIPT_FILENAME=" + cgifile,
+ "REQUEST_URI=/foo/bar"},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ "env-SCRIPT_FILENAME": cgifile,
+ "env-REQUEST_URI": "/foo/bar",
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
diff --git a/src/pkg/net/http/cgi/matryoshka_test.go b/src/pkg/net/http/cgi/matryoshka_test.go
new file mode 100644
index 000000000..1a44df204
--- /dev/null
+++ b/src/pkg/net/http/cgi/matryoshka_test.go
@@ -0,0 +1,74 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests a Go CGI program running under a Go CGI host process.
+// Further, the two programs are the same binary, just checking
+// their environment to figure out what mode to run in.
+
+package cgi
+
+import (
+ "fmt"
+ "net/http"
+ "os"
+ "testing"
+)
+
+// This test is a CGI host (testing host.go) that runs its own binary
+// as a child process testing the other half of CGI (child.go).
+func TestHostingOurselves(t *testing.T) {
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI-in-CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.go?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": os.Args[0],
+ "env-SCRIPT_NAME": "/test.go",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+// Note: not actually a test.
+func TestBeChildCGIProcess(t *testing.T) {
+ if os.Getenv("REQUEST_METHOD") == "" {
+ // Not in a CGI environment; skipping test.
+ return
+ }
+ Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Test-Header", "X-Test-Value")
+ fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
+ req.ParseForm()
+ for k, vv := range req.Form {
+ for _, v := range vv {
+ fmt.Fprintf(rw, "param-%s=%s\n", k, v)
+ }
+ }
+ for _, kv := range os.Environ() {
+ fmt.Fprintf(rw, "env-%s\n", kv)
+ }
+ }))
+ os.Exit(0)
+}
diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi
new file mode 100755
index 000000000..b46b1330f
--- /dev/null
+++ b/src/pkg/net/http/cgi/testdata/test.cgi
@@ -0,0 +1,96 @@
+#!/usr/bin/perl
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+#
+# Test script run as a child process under cgi_test.go
+
+use strict;
+use Cwd;
+
+my $q = MiniCGI->new;
+my $params = $q->Vars;
+
+if ($params->{"loc"}) {
+ print "Location: $params->{loc}\r\n\r\n";
+ exit(0);
+}
+
+my $NL = "\r\n";
+$NL = "\n" if $params->{mode} eq "NL";
+
+my $p = sub {
+ print "$_[0]$NL";
+};
+
+# With carriage returns
+$p->("Content-Type: text/html");
+$p->("X-CGI-Pid: $$");
+$p->("X-Test-Header: X-Test-Value");
+$p->("");
+
+if ($params->{"bigresponse"}) {
+ for (1..1024) {
+ print "A" x 1024, "\n";
+ }
+ exit 0;
+}
+
+print "test=Hello CGI\n";
+
+foreach my $k (sort keys %$params) {
+ print "param-$k=$params->{$k}\n";
+}
+
+foreach my $k (sort keys %ENV) {
+ my $clean_env = $ENV{$k};
+ $clean_env =~ s/[\n\r]//g;
+ print "env-$k=$clean_env\n";
+}
+
+# NOTE: don't call getcwd() for windows.
+# msys return /c/go/src/... not C:\go\...
+my $dir;
+if ($^O eq 'MSWin32' || $^O eq 'msys') {
+ my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
+ $cmd =~ s!\\!/!g;
+ $dir = `$cmd /c cd`;
+ chomp $dir;
+} else {
+ $dir = getcwd();
+}
+print "cwd=$dir\n";
+
+
+# A minimal version of CGI.pm, for people without the perl-modules
+# package installed. (CGI.pm used to be part of the Perl core, but
+# some distros now bundle perl-base and perl-modules separately...)
+package MiniCGI;
+
+sub new {
+ my $class = shift;
+ return bless {}, $class;
+}
+
+sub Vars {
+ my $self = shift;
+ my $pairs;
+ if ($ENV{CONTENT_LENGTH}) {
+ $pairs = do { local $/; <STDIN> };
+ } else {
+ $pairs = $ENV{QUERY_STRING};
+ }
+ my $vars = {};
+ foreach my $kv (split(/&/, $pairs)) {
+ my ($k, $v) = split(/=/, $kv, 2);
+ $vars->{_urldecode($k)} = _urldecode($v);
+ }
+ return $vars;
+}
+
+sub _urldecode {
+ my $v = shift;
+ $v =~ tr/+/ /;
+ $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg;
+ return $v;
+}
diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go
new file mode 100644
index 000000000..60a478fd8
--- /dev/null
+++ b/src/pkg/net/http/chunked.go
@@ -0,0 +1,170 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The wire protocol for HTTP's "chunked" Transfer-Encoding.
+
+// This code is duplicated in httputil/chunked.go.
+// Please make any changes in both files.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "io"
+ "strconv"
+)
+
+const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
+
+var ErrLineTooLong = errors.New("header line too long")
+
+// newChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// newChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func newChunkedReader(r io.Reader) io.Reader {
+ br, ok := r.(*bufio.Reader)
+ if !ok {
+ br = bufio.NewReader(r)
+ }
+ return &chunkedReader{r: br}
+}
+
+type chunkedReader struct {
+ r *bufio.Reader
+ n uint64 // unread bytes in chunk
+ err error
+}
+
+func (cr *chunkedReader) beginChunk() {
+ // chunk-size CRLF
+ var line string
+ line, cr.err = readLine(cr.r)
+ if cr.err != nil {
+ return
+ }
+ cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ if cr.err != nil {
+ return
+ }
+ if cr.n == 0 {
+ cr.err = io.EOF
+ }
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ if cr.err != nil {
+ return 0, cr.err
+ }
+ if cr.n == 0 {
+ cr.beginChunk()
+ if cr.err != nil {
+ return 0, cr.err
+ }
+ }
+ if uint64(len(b)) > cr.n {
+ b = b[0:cr.n]
+ }
+ n, cr.err = cr.r.Read(b)
+ cr.n -= uint64(n)
+ if cr.n == 0 && cr.err == nil {
+ // end of chunk (CRLF)
+ b := make([]byte, 2)
+ if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
+ if b[0] != '\r' || b[1] != '\n' {
+ cr.err = errors.New("malformed chunked encoding")
+ }
+ }
+ }
+ return n, cr.err
+}
+
+// Read a line of bytes (up to \n) from b.
+// Give up if the line exceeds maxLineLength.
+// The returned bytes are a pointer into storage in
+// the bufio, so they are only valid until the next bufio read.
+func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+ if p, err = b.ReadSlice('\n'); err != nil {
+ // We always know when EOF is coming.
+ // If the caller asked for a line, there should be a line.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ } else if err == bufio.ErrBufferFull {
+ err = ErrLineTooLong
+ }
+ return nil, err
+ }
+ if len(p) >= maxLineLength {
+ return nil, ErrLineTooLong
+ }
+
+ // Chop off trailing white space.
+ p = bytes.TrimRight(p, " \r\t\n")
+
+ return p, nil
+}
+
+// readLineBytes, but convert the bytes into a string.
+func readLine(b *bufio.Reader) (s string, err error) {
+ p, e := readLineBytes(b)
+ if e != nil {
+ return "", e
+ }
+ return string(p), nil
+}
+
+// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream.
+//
+// newChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using newChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func newChunkedWriter(w io.Writer) io.WriteCloser {
+ return &chunkedWriter{w}
+}
+
+// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
+// Encoding wire format to the underlying Wire chunkedWriter.
+type chunkedWriter struct {
+ Wire io.Writer
+}
+
+// Write the contents of data as one chunk to Wire.
+// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has
+// a bug since it does not check for success of io.WriteString
+func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
+
+ // Don't send 0-length data. It looks like EOF for chunked encoding.
+ if len(data) == 0 {
+ return 0, nil
+ }
+
+ head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
+
+ if _, err = io.WriteString(cw.Wire, head); err != nil {
+ return 0, err
+ }
+ if n, err = cw.Wire.Write(data); err != nil {
+ return
+ }
+ if n != len(data) {
+ err = io.ErrShortWrite
+ return
+ }
+ _, err = io.WriteString(cw.Wire, "\r\n")
+
+ return
+}
+
+func (cw *chunkedWriter) Close() error {
+ _, err := io.WriteString(cw.Wire, "0\r\n")
+ return err
+}
diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go
new file mode 100644
index 000000000..b77ee2ff2
--- /dev/null
+++ b/src/pkg/net/http/chunked_test.go
@@ -0,0 +1,39 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This code is duplicated in httputil/chunked_test.go.
+// Please make any changes in both files.
+
+package http
+
+import (
+ "bytes"
+ "io/ioutil"
+ "testing"
+)
+
+func TestChunk(t *testing.T) {
+ var b bytes.Buffer
+
+ w := newChunkedWriter(&b)
+ const chunk1 = "hello, "
+ const chunk2 = "world! 0123456789abcdef"
+ w.Write([]byte(chunk1))
+ w.Write([]byte(chunk2))
+ w.Close()
+
+ if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
+ t.Fatalf("chunk writer wrote %q; want %q", g, e)
+ }
+
+ r := newChunkedReader(&b)
+ data, err := ioutil.ReadAll(r)
+ if err != nil {
+ t.Logf(`data: "%s"`, data)
+ t.Fatalf("ReadAll from reader: %v", err)
+ }
+ if g, e := string(data), chunk1+chunk2; g != e {
+ t.Errorf("chunk reader read %q; want %q", g, e)
+ }
+}
diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go
new file mode 100644
index 000000000..c9f024017
--- /dev/null
+++ b/src/pkg/net/http/client.go
@@ -0,0 +1,326 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client. See RFC 2616.
+//
+// This is the high-level Client interface.
+// The low-level implementation is in transport.go.
+
+package http
+
+import (
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net/url"
+ "strings"
+)
+
+// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client
+// that uses DefaultTransport.
+//
+// The Client's Transport typically has internal state (cached
+// TCP connections), so Clients should be reused instead of created as
+// needed. Clients are safe for concurrent use by multiple goroutines.
+type Client struct {
+ // Transport specifies the mechanism by which individual
+ // HTTP requests are made.
+ // If nil, DefaultTransport is used.
+ Transport RoundTripper
+
+ // CheckRedirect specifies the policy for handling redirects.
+ // If CheckRedirect is not nil, the client calls it before
+ // following an HTTP redirect. The arguments req and via
+ // are the upcoming request and the requests made already,
+ // oldest first. If CheckRedirect returns an error, the client
+ // returns that error instead of issue the Request req.
+ //
+ // If CheckRedirect is nil, the Client uses its default policy,
+ // which is to stop after 10 consecutive requests.
+ CheckRedirect func(req *Request, via []*Request) error
+
+ // Jar specifies the cookie jar.
+ // If Jar is nil, cookies are not sent in requests and ignored
+ // in responses.
+ Jar CookieJar
+}
+
+// DefaultClient is the default Client and is used by Get, Head, and Post.
+var DefaultClient = &Client{}
+
+// RoundTripper is an interface representing the ability to execute a
+// single HTTP transaction, obtaining the Response for a given Request.
+//
+// A RoundTripper must be safe for concurrent use by multiple
+// goroutines.
+type RoundTripper interface {
+ // RoundTrip executes a single HTTP transaction, returning
+ // the Response for the request req. RoundTrip should not
+ // attempt to interpret the response. In particular,
+ // RoundTrip must return err == nil if it obtained a response,
+ // regardless of the response's HTTP status code. A non-nil
+ // err should be reserved for failure to obtain a response.
+ // Similarly, RoundTrip should not attempt to handle
+ // higher-level protocol details such as redirects,
+ // authentication, or cookies.
+ //
+ // RoundTrip should not modify the request, except for
+ // consuming the Body. The request's URL and Header fields
+ // are guaranteed to be initialized.
+ RoundTrip(*Request) (*Response, error)
+}
+
+// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
+// return true if the string includes a port.
+func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
+
+// Used in Send to implement io.ReadCloser by bundling together the
+// bufio.Reader through which we read the response, and the underlying
+// network connection.
+type readClose struct {
+ io.Reader
+ io.Closer
+}
+
+// Do sends an HTTP request and returns an HTTP response, following
+// policy (e.g. redirects, cookies, auth) as configured on the client.
+//
+// A non-nil response always contains a non-nil resp.Body.
+//
+// Callers should close resp.Body when done reading from it. If
+// resp.Body is not closed, the Client's underlying RoundTripper
+// (typically Transport) may not be able to re-use a persistent TCP
+// connection to the server for a subsequent "keep-alive" request.
+//
+// Generally Get, Post, or PostForm will be used instead of Do.
+func (c *Client) Do(req *Request) (resp *Response, err error) {
+ if req.Method == "GET" || req.Method == "HEAD" {
+ return c.doFollowingRedirects(req)
+ }
+ return send(req, c.Transport)
+}
+
+// send issues an HTTP request. Caller should close resp.Body when done reading from it.
+func send(req *Request, t RoundTripper) (resp *Response, err error) {
+ if t == nil {
+ t = DefaultTransport
+ if t == nil {
+ err = errors.New("http: no Client.Transport or DefaultTransport")
+ return
+ }
+ }
+
+ if req.URL == nil {
+ return nil, errors.New("http: nil Request.URL")
+ }
+
+ if req.RequestURI != "" {
+ return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
+ }
+
+ // Most the callers of send (Get, Post, et al) don't need
+ // Headers, leaving it uninitialized. We guarantee to the
+ // Transport that this has been initialized, though.
+ if req.Header == nil {
+ req.Header = make(Header)
+ }
+
+ if u := req.URL.User; u != nil {
+ req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String())))
+ }
+ return t.RoundTrip(req)
+}
+
+// True if the specified HTTP status code is one for which the Get utility should
+// automatically redirect.
+func shouldRedirect(statusCode int) bool {
+ switch statusCode {
+ case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
+ return true
+ }
+ return false
+}
+
+// Get issues a GET to the specified URL. If the response is one of the following
+// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// Caller should close r.Body when done reading from it.
+//
+// Get is a wrapper around DefaultClient.Get.
+func Get(url string) (r *Response, err error) {
+ return DefaultClient.Get(url)
+}
+
+// Get issues a GET to the specified URL. If the response is one of the
+// following redirect codes, Get follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// Caller should close r.Body when done reading from it.
+func (c *Client) Get(url string) (r *Response, err error) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.doFollowingRedirects(req)
+}
+
+func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
+ // TODO: if/when we add cookie support, the redirected request shouldn't
+ // necessarily supply the same cookies as the original.
+ var base *url.URL
+ redirectChecker := c.CheckRedirect
+ if redirectChecker == nil {
+ redirectChecker = defaultCheckRedirect
+ }
+ var via []*Request
+
+ if ireq.URL == nil {
+ return nil, errors.New("http: nil Request.URL")
+ }
+
+ jar := c.Jar
+ if jar == nil {
+ jar = blackHoleJar{}
+ }
+
+ req := ireq
+ urlStr := "" // next relative or absolute URL to fetch (after first request)
+ for redirect := 0; ; redirect++ {
+ if redirect != 0 {
+ req = new(Request)
+ req.Method = ireq.Method
+ req.Header = make(Header)
+ req.URL, err = base.Parse(urlStr)
+ if err != nil {
+ break
+ }
+ if len(via) > 0 {
+ // Add the Referer header.
+ lastReq := via[len(via)-1]
+ if lastReq.URL.Scheme != "https" {
+ req.Header.Set("Referer", lastReq.URL.String())
+ }
+
+ err = redirectChecker(req, via)
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ for _, cookie := range jar.Cookies(req.URL) {
+ req.AddCookie(cookie)
+ }
+ urlStr = req.URL.String()
+ if r, err = send(req, c.Transport); err != nil {
+ break
+ }
+ if c := r.Cookies(); len(c) > 0 {
+ jar.SetCookies(req.URL, c)
+ }
+
+ if shouldRedirect(r.StatusCode) {
+ r.Body.Close()
+ if urlStr = r.Header.Get("Location"); urlStr == "" {
+ err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode))
+ break
+ }
+ base = req.URL
+ via = append(via, req)
+ continue
+ }
+ return
+ }
+
+ method := ireq.Method
+ err = &url.Error{method[0:1] + strings.ToLower(method[1:]), urlStr, err}
+ return
+}
+
+func defaultCheckRedirect(req *Request, via []*Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
+ }
+ return nil
+}
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close r.Body when done reading from it.
+//
+// Post is a wrapper around DefaultClient.Post
+func Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+ return DefaultClient.Post(url, bodyType, body)
+}
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close r.Body when done reading from it.
+func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+ req, err := NewRequest("POST", url, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", bodyType)
+ return send(req, c.Transport)
+}
+
+// PostForm issues a POST to the specified URL,
+// with data's keys and values urlencoded as the request body.
+//
+// Caller should close r.Body when done reading from it.
+//
+// PostForm is a wrapper around DefaultClient.PostForm
+func PostForm(url string, data url.Values) (r *Response, err error) {
+ return DefaultClient.PostForm(url, data)
+}
+
+// PostForm issues a POST to the specified URL,
+// with data's keys and values urlencoded as the request body.
+//
+// Caller should close r.Body when done reading from it.
+func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) {
+ return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
+
+// Head issues a HEAD to the specified URL. If the response is one of the
+// following redirect codes, Head follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// Head is a wrapper around DefaultClient.Head
+func Head(url string) (r *Response, err error) {
+ return DefaultClient.Head(url)
+}
+
+// Head issues a HEAD to the specified URL. If the response is one of the
+// following redirect codes, Head follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+func (c *Client) Head(url string) (r *Response, err error) {
+ req, err := NewRequest("HEAD", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.doFollowingRedirects(req)
+}
diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go
new file mode 100644
index 000000000..aa0bf4be6
--- /dev/null
+++ b/src/pkg/net/http/client_test.go
@@ -0,0 +1,442 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for client.go
+
+package http_test
+
+import (
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+)
+
+var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Last-Modified", "sometime")
+ fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
+})
+
+// pedanticReadAll works like ioutil.ReadAll but additionally
+// verifies that r obeys the documented io.Reader contract.
+func pedanticReadAll(r io.Reader) (b []byte, err error) {
+ var bufa [64]byte
+ buf := bufa[:]
+ for {
+ n, err := r.Read(buf)
+ if n == 0 && err == nil {
+ return nil, fmt.Errorf("Read: n=0 with err=nil")
+ }
+ b = append(b, buf[:n]...)
+ if err == io.EOF {
+ n, err := r.Read(buf)
+ if n != 0 || err != io.EOF {
+ return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
+ }
+ return b, nil
+ }
+ if err != nil {
+ return b, err
+ }
+ }
+ panic("unreachable")
+}
+
+func TestClient(t *testing.T) {
+ ts := httptest.NewServer(robotsTxtHandler)
+ defer ts.Close()
+
+ r, err := Get(ts.URL)
+ var b []byte
+ if err == nil {
+ b, err = pedanticReadAll(r.Body)
+ r.Body.Close()
+ }
+ if err != nil {
+ t.Error(err)
+ } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
+ t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
+ }
+}
+
+func TestClientHead(t *testing.T) {
+ ts := httptest.NewServer(robotsTxtHandler)
+ defer ts.Close()
+
+ r, err := Head(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := r.Header["Last-Modified"]; !ok {
+ t.Error("Last-Modified header not found.")
+ }
+}
+
+type recordingTransport struct {
+ req *Request
+}
+
+func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ t.req = req
+ return nil, errors.New("dummy impl")
+}
+
+func TestGetRequestFormat(t *testing.T) {
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+ url := "http://dummy.faketld/"
+ client.Get(url) // Note: doesn't hit network
+ if tr.req.Method != "GET" {
+ t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
+ }
+ if tr.req.Header == nil {
+ t.Errorf("expected non-nil request Header")
+ }
+}
+
+func TestPostRequestFormat(t *testing.T) {
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ url := "http://dummy.faketld/"
+ json := `{"key":"value"}`
+ b := strings.NewReader(json)
+ client.Post(url, "application/json", b) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ if g, e := tr.req.ContentLength, int64(len(json)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+}
+
+func TestPostFormRequestFormat(t *testing.T) {
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ urlStr := "http://dummy.faketld/"
+ form := make(url.Values)
+ form.Set("foo", "bar")
+ form.Add("foo", "bar2")
+ form.Set("bar", "baz")
+ client.PostForm(urlStr, form) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != urlStr {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
+ t.Errorf("got Content-Type %q, want %q", g, e)
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ // Depending on map iteration, body can be either of these.
+ expectedBody := "foo=bar&foo=bar2&bar=baz"
+ expectedBody1 := "bar=baz&foo=bar&foo=bar2"
+ if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+ bodyb, err := ioutil.ReadAll(tr.req.Body)
+ if err != nil {
+ t.Fatalf("ReadAll on req.Body: %v", err)
+ }
+ if g := string(bodyb); g != expectedBody && g != expectedBody1 {
+ t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
+ }
+}
+
+func TestRedirects(t *testing.T) {
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ n, _ := strconv.Atoi(r.FormValue("n"))
+ // Test Referer header. (7 is arbitrary position to test at)
+ if n == 7 {
+ if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
+ t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
+ }
+ }
+ if n < 15 {
+ Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
+ return
+ }
+ fmt.Fprintf(w, "n=%d", n)
+ }))
+ defer ts.Close()
+
+ c := &Client{}
+ _, err := c.Get(ts.URL)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Get, expected error %q, got %q", e, g)
+ }
+
+ // HEAD request should also have the ability to follow redirects.
+ _, err = c.Head(ts.URL)
+ if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Head, expected error %q, got %q", e, g)
+ }
+
+ // Do should also follow redirects.
+ greq, _ := NewRequest("GET", ts.URL, nil)
+ _, err = c.Do(greq)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Do, expected error %q, got %q", e, g)
+ }
+
+ var checkErr error
+ var lastVia []*Request
+ c = &Client{CheckRedirect: func(_ *Request, via []*Request) error {
+ lastVia = via
+ return checkErr
+ }}
+ res, err := c.Get(ts.URL)
+ finalUrl := res.Request.URL.String()
+ if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with custom client, expected error %q, got %q", e, g)
+ }
+ if !strings.HasSuffix(finalUrl, "/?n=15") {
+ t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
+ }
+ if e, g := 15, len(lastVia); e != g {
+ t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
+ }
+
+ checkErr = errors.New("no redirects allowed")
+ res, err = c.Get(ts.URL)
+ finalUrl = res.Request.URL.String()
+ if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+ }
+}
+
+var expectedCookies = []*Cookie{
+ &Cookie{Name: "ChocolateChip", Value: "tasty"},
+ &Cookie{Name: "First", Value: "Hit"},
+ &Cookie{Name: "Second", Value: "Hit"},
+}
+
+var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ for _, cookie := range r.Cookies() {
+ SetCookie(w, cookie)
+ }
+ if r.URL.Path == "/" {
+ SetCookie(w, expectedCookies[1])
+ Redirect(w, r, "/second", StatusMovedPermanently)
+ } else {
+ SetCookie(w, expectedCookies[2])
+ w.Write([]byte("hello"))
+ }
+})
+
+// Just enough correctness for our redirect tests. Uses the URL.Host as the
+// scope of all cookies.
+type TestJar struct {
+ m sync.Mutex
+ perURL map[string][]*Cookie
+}
+
+func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.m.Lock()
+ defer j.m.Unlock()
+ j.perURL[u.Host] = cookies
+}
+
+func (j *TestJar) Cookies(u *url.URL) []*Cookie {
+ j.m.Lock()
+ defer j.m.Unlock()
+ return j.perURL[u.Host]
+}
+
+func TestRedirectCookiesOnRequest(t *testing.T) {
+ var ts *httptest.Server
+ ts = httptest.NewServer(echoCookiesRedirectHandler)
+ defer ts.Close()
+ c := &Client{}
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.AddCookie(expectedCookies[0])
+ // TODO: Uncomment when an implementation of a RFC6265 cookie jar lands.
+ _ = c
+ // resp, _ := c.Do(req)
+ // matchReturnedCookies(t, expectedCookies, resp.Cookies())
+
+ req, _ = NewRequest("GET", ts.URL, nil)
+ // resp, _ = c.Do(req)
+ // matchReturnedCookies(t, expectedCookies[1:], resp.Cookies())
+}
+
+func TestRedirectCookiesJar(t *testing.T) {
+ var ts *httptest.Server
+ ts = httptest.NewServer(echoCookiesRedirectHandler)
+ defer ts.Close()
+ c := &Client{}
+ c.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
+ u, _ := url.Parse(ts.URL)
+ c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
+ resp, _ := c.Get(ts.URL)
+ matchReturnedCookies(t, expectedCookies, resp.Cookies())
+}
+
+func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
+ t.Logf("Received cookies: %v", given)
+ if len(given) != len(expected) {
+ t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
+ }
+ for _, ec := range expected {
+ foundC := false
+ for _, c := range given {
+ if ec.Name == c.Name && ec.Value == c.Value {
+ foundC = true
+ break
+ }
+ }
+ if !foundC {
+ t.Errorf("Missing cookie %v", ec)
+ }
+ }
+}
+
+func TestStreamingGet(t *testing.T) {
+ say := make(chan string)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ for str := range say {
+ w.Write([]byte(str))
+ w.(Flusher).Flush()
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var buf [10]byte
+ for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
+ say <- str
+ n, err := io.ReadFull(res.Body, buf[0:len(str)])
+ if err != nil {
+ t.Fatalf("ReadFull on %q: %v", str, err)
+ }
+ if n != len(str) {
+ t.Fatalf("Receiving %q, only read %d bytes", str, n)
+ }
+ got := string(buf[0:n])
+ if got != str {
+ t.Fatalf("Expected %q, got %q", str, got)
+ }
+ }
+ close(say)
+ _, err = io.ReadFull(res.Body, buf[0:1])
+ if err != io.EOF {
+ t.Fatalf("at end expected EOF, got %v", err)
+ }
+}
+
+type writeCountingConn struct {
+ net.Conn
+ count *int
+}
+
+func (c *writeCountingConn) Write(p []byte) (int, error) {
+ *c.count++
+ return c.Conn.Write(p)
+}
+
+// TestClientWrites verifies that client requests are buffered and we
+// don't send a TCP packet per line of the http request + body.
+func TestClientWrites(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ writes := 0
+ dialer := func(netz string, addr string) (net.Conn, error) {
+ c, err := net.Dial(netz, addr)
+ if err == nil {
+ c = &writeCountingConn{c, &writes}
+ }
+ return c, err
+ }
+ c := &Client{Transport: &Transport{Dial: dialer}}
+
+ _, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Get request did %d Write calls, want 1", writes)
+ }
+
+ writes = 0
+ _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Post request did %d Write calls, want 1", writes)
+ }
+}
+
+func TestClientInsecureTransport(t *testing.T) {
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ }))
+ defer ts.Close()
+
+ // TODO(bradfitz): add tests for skipping hostname checks too?
+ // would require a new cert for testing, and probably
+ // redundant with these tests.
+ for _, insecure := range []bool{true, false} {
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: insecure,
+ },
+ }
+ c := &Client{Transport: tr}
+ _, err := c.Get(ts.URL)
+ if (err == nil) != insecure {
+ t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
+ }
+ }
+}
+
+func TestClientErrorWithRequestURI(t *testing.T) {
+ req, _ := NewRequest("GET", "http://localhost:1234/", nil)
+ req.RequestURI = "/this/field/is/illegal/and/should/error/"
+ _, err := DefaultClient.Do(req)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "RequestURI") {
+ t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
+ }
+}
diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go
new file mode 100644
index 000000000..2e30bbff1
--- /dev/null
+++ b/src/pkg/net/http/cookie.go
@@ -0,0 +1,267 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// This implementation is done according to RFC 6265:
+//
+// http://tools.ietf.org/html/rfc6265
+
+// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
+// HTTP response or the Cookie header of an HTTP request.
+type Cookie struct {
+ Name string
+ Value string
+ Path string
+ Domain string
+ Expires time.Time
+ RawExpires string
+
+ // MaxAge=0 means no 'Max-Age' attribute specified.
+ // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
+ // MaxAge>0 means Max-Age attribute present and given in seconds
+ MaxAge int
+ Secure bool
+ HttpOnly bool
+ Raw string
+ Unparsed []string // Raw text of unparsed attribute-value pairs
+}
+
+// readSetCookies parses all "Set-Cookie" values from
+// the header h and returns the successfully parsed Cookies.
+func readSetCookies(h Header) []*Cookie {
+ cookies := []*Cookie{}
+ for _, line := range h["Set-Cookie"] {
+ parts := strings.Split(strings.TrimSpace(line), ";")
+ if len(parts) == 1 && parts[0] == "" {
+ continue
+ }
+ parts[0] = strings.TrimSpace(parts[0])
+ j := strings.Index(parts[0], "=")
+ if j < 0 {
+ continue
+ }
+ name, value := parts[0][:j], parts[0][j+1:]
+ if !isCookieNameValid(name) {
+ continue
+ }
+ value, success := parseCookieValue(value)
+ if !success {
+ continue
+ }
+ c := &Cookie{
+ Name: name,
+ Value: value,
+ Raw: line,
+ }
+ for i := 1; i < len(parts); i++ {
+ parts[i] = strings.TrimSpace(parts[i])
+ if len(parts[i]) == 0 {
+ continue
+ }
+
+ attr, val := parts[i], ""
+ if j := strings.Index(attr, "="); j >= 0 {
+ attr, val = attr[:j], attr[j+1:]
+ }
+ lowerAttr := strings.ToLower(attr)
+ parseCookieValueFn := parseCookieValue
+ if lowerAttr == "expires" {
+ parseCookieValueFn = parseCookieExpiresValue
+ }
+ val, success = parseCookieValueFn(val)
+ if !success {
+ c.Unparsed = append(c.Unparsed, parts[i])
+ continue
+ }
+ switch lowerAttr {
+ case "secure":
+ c.Secure = true
+ continue
+ case "httponly":
+ c.HttpOnly = true
+ continue
+ case "domain":
+ c.Domain = val
+ // TODO: Add domain parsing
+ continue
+ case "max-age":
+ secs, err := strconv.Atoi(val)
+ if err != nil || secs != 0 && val[0] == '0' {
+ break
+ }
+ if secs <= 0 {
+ c.MaxAge = -1
+ } else {
+ c.MaxAge = secs
+ }
+ continue
+ case "expires":
+ c.RawExpires = val
+ exptime, err := time.Parse(time.RFC1123, val)
+ if err != nil {
+ exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val)
+ if err != nil {
+ c.Expires = time.Time{}
+ break
+ }
+ }
+ c.Expires = exptime.UTC()
+ continue
+ case "path":
+ c.Path = val
+ // TODO: Add path parsing
+ continue
+ }
+ c.Unparsed = append(c.Unparsed, parts[i])
+ }
+ cookies = append(cookies, c)
+ }
+ return cookies
+}
+
+// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
+func SetCookie(w ResponseWriter, cookie *Cookie) {
+ w.Header().Add("Set-Cookie", cookie.String())
+}
+
+// String returns the serialization of the cookie for use in a Cookie
+// header (if only Name and Value are set) or a Set-Cookie response
+// header (if other fields are set).
+func (c *Cookie) String() string {
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value))
+ if len(c.Path) > 0 {
+ fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path))
+ }
+ if len(c.Domain) > 0 {
+ fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain))
+ }
+ if c.Expires.Unix() > 0 {
+ fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123))
+ }
+ if c.MaxAge > 0 {
+ fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge)
+ } else if c.MaxAge < 0 {
+ fmt.Fprintf(&b, "; Max-Age=0")
+ }
+ if c.HttpOnly {
+ fmt.Fprintf(&b, "; HttpOnly")
+ }
+ if c.Secure {
+ fmt.Fprintf(&b, "; Secure")
+ }
+ return b.String()
+}
+
+// readCookies parses all "Cookie" values from the header h and
+// returns the successfully parsed Cookies.
+//
+// if filter isn't empty, only cookies of that name are returned
+func readCookies(h Header, filter string) []*Cookie {
+ cookies := []*Cookie{}
+ lines, ok := h["Cookie"]
+ if !ok {
+ return cookies
+ }
+
+ for _, line := range lines {
+ parts := strings.Split(strings.TrimSpace(line), ";")
+ if len(parts) == 1 && parts[0] == "" {
+ continue
+ }
+ // Per-line attributes
+ parsedPairs := 0
+ for i := 0; i < len(parts); i++ {
+ parts[i] = strings.TrimSpace(parts[i])
+ if len(parts[i]) == 0 {
+ continue
+ }
+ name, val := parts[i], ""
+ if j := strings.Index(name, "="); j >= 0 {
+ name, val = name[:j], name[j+1:]
+ }
+ if !isCookieNameValid(name) {
+ continue
+ }
+ if filter != "" && filter != name {
+ continue
+ }
+ val, success := parseCookieValue(val)
+ if !success {
+ continue
+ }
+ cookies = append(cookies, &Cookie{Name: name, Value: val})
+ parsedPairs++
+ }
+ }
+ return cookies
+}
+
+var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
+
+func sanitizeName(n string) string {
+ return cookieNameSanitizer.Replace(n)
+}
+
+var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ")
+
+func sanitizeValue(v string) string {
+ return cookieValueSanitizer.Replace(v)
+}
+
+func unquoteCookieValue(v string) string {
+ if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' {
+ return v[1 : len(v)-1]
+ }
+ return v
+}
+
+func isCookieByte(c byte) bool {
+ switch {
+ case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a,
+ 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e:
+ return true
+ }
+ return false
+}
+
+func isCookieExpiresByte(c byte) (ok bool) {
+ return isCookieByte(c) || c == ',' || c == ' '
+}
+
+func parseCookieValue(raw string) (string, bool) {
+ return parseCookieValueUsing(raw, isCookieByte)
+}
+
+func parseCookieExpiresValue(raw string) (string, bool) {
+ return parseCookieValueUsing(raw, isCookieExpiresByte)
+}
+
+func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) {
+ raw = unquoteCookieValue(raw)
+ for i := 0; i < len(raw); i++ {
+ if !validByte(raw[i]) {
+ return "", false
+ }
+ }
+ return raw, true
+}
+
+func isCookieNameValid(raw string) bool {
+ for _, c := range raw {
+ if !isToken(byte(c)) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go
new file mode 100644
index 000000000..712350dfc
--- /dev/null
+++ b/src/pkg/net/http/cookie_test.go
@@ -0,0 +1,200 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var writeSetCookiesTests = []struct {
+ Cookie *Cookie
+ Raw string
+}{
+ {
+ &Cookie{Name: "cookie-1", Value: "v$1"},
+ "cookie-1=v$1",
+ },
+ {
+ &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600},
+ "cookie-2=two; Max-Age=3600",
+ },
+ {
+ &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"},
+ "cookie-3=three; Domain=.example.com",
+ },
+ {
+ &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"},
+ "cookie-4=four; Path=/restricted/",
+ },
+}
+
+func TestWriteSetCookies(t *testing.T) {
+ for i, tt := range writeSetCookiesTests {
+ if g, e := tt.Cookie.String(), tt.Raw; g != e {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g)
+ continue
+ }
+ }
+}
+
+type headerOnlyResponseWriter Header
+
+func (ho headerOnlyResponseWriter) Header() Header {
+ return Header(ho)
+}
+
+func (ho headerOnlyResponseWriter) Write([]byte) (int, error) {
+ panic("NOIMPL")
+}
+
+func (ho headerOnlyResponseWriter) WriteHeader(int) {
+ panic("NOIMPL")
+}
+
+func TestSetCookie(t *testing.T) {
+ m := make(Header)
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"})
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600})
+ if l := len(m["Set-Cookie"]); l != 2 {
+ t.Fatalf("expected %d cookies, got %d", 2, l)
+ }
+ if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e {
+ t.Errorf("cookie #1: want %q, got %q", e, g)
+ }
+ if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e {
+ t.Errorf("cookie #2: want %q, got %q", e, g)
+ }
+}
+
+var addCookieTests = []struct {
+ Cookies []*Cookie
+ Raw string
+}{
+ {
+ []*Cookie{},
+ "",
+ },
+ {
+ []*Cookie{{Name: "cookie-1", Value: "v$1"}},
+ "cookie-1=v$1",
+ },
+ {
+ []*Cookie{
+ {Name: "cookie-1", Value: "v$1"},
+ {Name: "cookie-2", Value: "v$2"},
+ {Name: "cookie-3", Value: "v$3"},
+ },
+ "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3",
+ },
+}
+
+func TestAddCookie(t *testing.T) {
+ for i, tt := range addCookieTests {
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ for _, c := range tt.Cookies {
+ req.AddCookie(c)
+ }
+ if g := req.Header.Get("Cookie"); g != tt.Raw {
+ t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g)
+ continue
+ }
+ }
+}
+
+var readSetCookiesTests = []struct {
+ Header Header
+ Cookies []*Cookie
+}{
+ {
+ Header{"Set-Cookie": {"Cookie-1=v$1"}},
+ []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+ },
+ {
+ Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}},
+ []*Cookie{{
+ Name: "NID",
+ Value: "99=YsDT5i3E-CXax-",
+ Path: "/",
+ Domain: ".google.ch",
+ HttpOnly: true,
+ Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT",
+ Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ }},
+ },
+}
+
+func toJSON(v interface{}) string {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("%#v", v)
+ }
+ return string(b)
+}
+
+func TestReadSetCookies(t *testing.T) {
+ for i, tt := range readSetCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input
+ c := readSetCookies(tt.Header)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
+
+var readCookiesTests = []struct {
+ Header Header
+ Filter string
+ Cookies []*Cookie
+}{
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+}
+
+func TestReadCookies(t *testing.T) {
+ for i, tt := range readCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
+ c := readCookies(tt.Header, tt.Filter)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/http/doc.go b/src/pkg/net/http/doc.go
new file mode 100644
index 000000000..8962ed31e
--- /dev/null
+++ b/src/pkg/net/http/doc.go
@@ -0,0 +1,80 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package http provides HTTP client and server implementations.
+
+Get, Head, Post, and PostForm make HTTP requests:
+
+ resp, err := http.Get("http://example.com/")
+ ...
+ resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf)
+ ...
+ resp, err := http.PostForm("http://example.com/form",
+ url.Values{"key": {"Value"}, "id": {"123"}})
+
+The client must close the response body when finished with it:
+
+ resp, err := http.Get("http://example.com/")
+ if err != nil {
+ // handle error
+ }
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ // ...
+
+For control over HTTP client headers, redirect policy, and other
+settings, create a Client:
+
+ client := &http.Client{
+ CheckRedirect: redirectPolicyFunc,
+ }
+
+ resp, err := client.Get("http://example.com")
+ // ...
+
+ req, err := http.NewRequest("GET", "http://example.com", nil)
+ // ...
+ req.Header.Add("If-None-Match", `W/"wyzzy"`)
+ resp, err := client.Do(req)
+ // ...
+
+For control over proxies, TLS configuration, keep-alives,
+compression, and other settings, create a Transport:
+
+ tr := &http.Transport{
+ TLSClientConfig: &tls.Config{RootCAs: pool},
+ DisableCompression: true,
+ }
+ client := &http.Client{Transport: tr}
+ resp, err := client.Get("https://example.com")
+
+Clients and Transports are safe for concurrent use by multiple
+goroutines and for efficiency should only be created once and re-used.
+
+ListenAndServe starts an HTTP server with a given address and handler.
+The handler is usually nil, which means to use DefaultServeMux.
+Handle and HandleFunc add handlers to DefaultServeMux:
+
+ http.Handle("/foo", fooHandler)
+
+ http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.RawPath))
+ })
+
+ log.Fatal(http.ListenAndServe(":8080", nil))
+
+More control over the server's behavior is available by creating a
+custom Server:
+
+ s := &http.Server{
+ Addr: ":8080",
+ Handler: myHandler,
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ MaxHeaderBytes: 1 << 20,
+ }
+ log.Fatal(s.ListenAndServe())
+*/
+package http
diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go
new file mode 100644
index 000000000..13640ca85
--- /dev/null
+++ b/src/pkg/net/http/export_test.go
@@ -0,0 +1,43 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Bridge package to expose http internals to tests in the http_test
+// package.
+
+package http
+
+import "time"
+
+func (t *Transport) IdleConnKeysForTesting() (keys []string) {
+ keys = make([]string, 0)
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.idleConn == nil {
+ return
+ }
+ for key := range t.idleConn {
+ keys = append(keys, key)
+ }
+ return
+}
+
+func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.idleConn == nil {
+ return 0
+ }
+ conns, ok := t.idleConn[cacheKey]
+ if !ok {
+ return 0
+ }
+ return len(conns)
+}
+
+func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
+ f := func() <-chan time.Time {
+ return ch
+ }
+ return &timeoutHandler{handler, f, ""}
+}
diff --git a/src/pkg/net/http/fcgi/Makefile b/src/pkg/net/http/fcgi/Makefile
new file mode 100644
index 000000000..9a75f1a80
--- /dev/null
+++ b/src/pkg/net/http/fcgi/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/http/fcgi
+GOFILES=\
+ child.go\
+ fcgi.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go
new file mode 100644
index 000000000..c94b9a7b2
--- /dev/null
+++ b/src/pkg/net/http/fcgi/child.go
@@ -0,0 +1,271 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+// This file implements FastCGI from the perspective of a child process.
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/cgi"
+ "os"
+ "time"
+)
+
+// request holds the state for an in-progress request. As soon as it's complete,
+// it's converted to an http.Request.
+type request struct {
+ pw *io.PipeWriter
+ reqId uint16
+ params map[string]string
+ buf [1024]byte
+ rawParams []byte
+ keepConn bool
+}
+
+func newRequest(reqId uint16, flags uint8) *request {
+ r := &request{
+ reqId: reqId,
+ params: map[string]string{},
+ keepConn: flags&flagKeepConn != 0,
+ }
+ r.rawParams = r.buf[:0]
+ return r
+}
+
+// parseParams reads an encoded []byte into Params.
+func (r *request) parseParams() {
+ text := r.rawParams
+ r.rawParams = nil
+ for len(text) > 0 {
+ keyLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ valLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ key := readString(text, keyLen)
+ text = text[keyLen:]
+ val := readString(text, valLen)
+ text = text[valLen:]
+ r.params[key] = val
+ }
+}
+
+// response implements http.ResponseWriter.
+type response struct {
+ req *request
+ header http.Header
+ w *bufWriter
+ wroteHeader bool
+}
+
+func newResponse(c *child, req *request) *response {
+ return &response{
+ req: req,
+ header: http.Header{},
+ w: newWriter(c.conn, typeStdout, req.reqId),
+ }
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(data []byte) (int, error) {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ return r.w.Write(data)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.wroteHeader {
+ return
+ }
+ r.wroteHeader = true
+ if code == http.StatusNotModified {
+ // Must not have body.
+ r.header.Del("Content-Type")
+ r.header.Del("Content-Length")
+ r.header.Del("Transfer-Encoding")
+ } else if r.header.Get("Content-Type") == "" {
+ r.header.Set("Content-Type", "text/html; charset=utf-8")
+ }
+
+ if r.header.Get("Date") == "" {
+ r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
+ }
+
+ fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code))
+ r.header.Write(r.w)
+ r.w.WriteString("\r\n")
+}
+
+func (r *response) Flush() {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ r.w.Flush()
+}
+
+func (r *response) Close() error {
+ r.Flush()
+ return r.w.Close()
+}
+
+type child struct {
+ conn *conn
+ handler http.Handler
+ requests map[uint16]*request // keyed by request ID
+}
+
+func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
+ return &child{
+ conn: newConn(rwc),
+ handler: handler,
+ requests: make(map[uint16]*request),
+ }
+}
+
+func (c *child) serve() {
+ defer c.conn.Close()
+ var rec record
+ for {
+ if err := rec.read(c.conn.rwc); err != nil {
+ return
+ }
+ if err := c.handleRecord(&rec); err != nil {
+ return
+ }
+ }
+}
+
+var errCloseConn = errors.New("fcgi: connection should be closed")
+
+func (c *child) handleRecord(rec *record) error {
+ req, ok := c.requests[rec.h.Id]
+ if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
+ // The spec says to ignore unknown request IDs.
+ return nil
+ }
+ if ok && rec.h.Type == typeBeginRequest {
+ // The server is trying to begin a request with the same ID
+ // as an in-progress request. This is an error.
+ return errors.New("fcgi: received ID that is already in-flight")
+ }
+
+ switch rec.h.Type {
+ case typeBeginRequest:
+ var br beginRequest
+ if err := br.read(rec.content()); err != nil {
+ return err
+ }
+ if br.role != roleResponder {
+ c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
+ return nil
+ }
+ c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
+ case typeParams:
+ // NOTE(eds): Technically a key-value pair can straddle the boundary
+ // between two packets. We buffer until we've received all parameters.
+ if len(rec.content()) > 0 {
+ req.rawParams = append(req.rawParams, rec.content()...)
+ return nil
+ }
+ req.parseParams()
+ case typeStdin:
+ content := rec.content()
+ if req.pw == nil {
+ var body io.ReadCloser
+ if len(content) > 0 {
+ // body could be an io.LimitReader, but it shouldn't matter
+ // as long as both sides are behaving.
+ body, req.pw = io.Pipe()
+ }
+ go c.serveRequest(req, body)
+ }
+ if len(content) > 0 {
+ // TODO(eds): This blocks until the handler reads from the pipe.
+ // If the handler takes a long time, it might be a problem.
+ req.pw.Write(content)
+ } else if req.pw != nil {
+ req.pw.Close()
+ }
+ case typeGetValues:
+ values := map[string]string{"FCGI_MPXS_CONNS": "1"}
+ c.conn.writePairs(typeGetValuesResult, 0, values)
+ case typeData:
+ // If the filter role is implemented, read the data stream here.
+ case typeAbortRequest:
+ delete(c.requests, rec.h.Id)
+ c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
+ if !req.keepConn {
+ // connection will close upon return
+ return errCloseConn
+ }
+ default:
+ b := make([]byte, 8)
+ b[0] = byte(rec.h.Type)
+ c.conn.writeRecord(typeUnknownType, 0, b)
+ }
+ return nil
+}
+
+func (c *child) serveRequest(req *request, body io.ReadCloser) {
+ r := newResponse(c, req)
+ httpReq, err := cgi.RequestFromMap(req.params)
+ if err != nil {
+ // there was an error reading the request
+ r.WriteHeader(http.StatusInternalServerError)
+ c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error()))
+ } else {
+ httpReq.Body = body
+ c.handler.ServeHTTP(r, httpReq)
+ }
+ if body != nil {
+ body.Close()
+ }
+ r.Close()
+ c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
+ if !req.keepConn {
+ c.conn.Close()
+ }
+}
+
+// Serve accepts incoming FastCGI connections on the listener l, creating a new
+// service thread for each. The service threads read requests and then call handler
+// to reply to them.
+// If l is nil, Serve accepts connections on stdin.
+// If handler is nil, http.DefaultServeMux is used.
+func Serve(l net.Listener, handler http.Handler) error {
+ if l == nil {
+ var err error
+ l, err = net.FileListener(os.Stdin)
+ if err != nil {
+ return err
+ }
+ defer l.Close()
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ for {
+ rw, err := l.Accept()
+ if err != nil {
+ return err
+ }
+ c := newChild(rw, handler)
+ go c.serve()
+ }
+ panic("unreachable")
+}
diff --git a/src/pkg/net/http/fcgi/fcgi.go b/src/pkg/net/http/fcgi/fcgi.go
new file mode 100644
index 000000000..d35aa84d2
--- /dev/null
+++ b/src/pkg/net/http/fcgi/fcgi.go
@@ -0,0 +1,274 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fcgi implements the FastCGI protocol.
+// Currently only the responder role is supported.
+// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22
+package fcgi
+
+// This file defines the raw protocol and some utilities used by the child and
+// the host.
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "sync"
+)
+
+// recType is a record type, as defined by
+// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8
+type recType uint8
+
+const (
+ typeBeginRequest recType = 1
+ typeAbortRequest recType = 2
+ typeEndRequest recType = 3
+ typeParams recType = 4
+ typeStdin recType = 5
+ typeStdout recType = 6
+ typeStderr recType = 7
+ typeData recType = 8
+ typeGetValues recType = 9
+ typeGetValuesResult recType = 10
+ typeUnknownType recType = 11
+)
+
+// keep the connection between web-server and responder open after request
+const flagKeepConn = 1
+
+const (
+ maxWrite = 65535 // maximum record body
+ maxPad = 255
+)
+
+const (
+ roleResponder = iota + 1 // only Responders are implemented.
+ roleAuthorizer
+ roleFilter
+)
+
+const (
+ statusRequestComplete = iota
+ statusCantMultiplex
+ statusOverloaded
+ statusUnknownRole
+)
+
+const headerLen = 8
+
+type header struct {
+ Version uint8
+ Type recType
+ Id uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+type beginRequest struct {
+ role uint16
+ flags uint8
+ reserved [5]uint8
+}
+
+func (br *beginRequest) read(content []byte) error {
+ if len(content) != 8 {
+ return errors.New("fcgi: invalid begin request record")
+ }
+ br.role = binary.BigEndian.Uint16(content)
+ br.flags = content[2]
+ return nil
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType recType, reqId uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.Id = reqId
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+// conn sends records over rwc
+type conn struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+
+ // to avoid allocations
+ buf bytes.Buffer
+ h header
+}
+
+func newConn(rwc io.ReadWriteCloser) *conn {
+ return &conn{rwc: rwc}
+}
+
+func (c *conn) Close() error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ return c.rwc.Close()
+}
+
+type record struct {
+ h header
+ buf [maxWrite + maxPad]byte
+}
+
+func (rec *record) read(r io.Reader) (err error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return err
+ }
+ if rec.h.Version != 1 {
+ return errors.New("fcgi: invalid header version")
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if _, err = io.ReadFull(r, rec.buf[:n]); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *record) content() []byte {
+ return r.buf[:r.h.ContentLength]
+}
+
+// writeRecord writes and sends a single record.
+func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, reqId, len(b))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(b); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err := c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) error {
+ b := [8]byte{byte(role >> 8), byte(role), flags}
+ return c.writeRecord(typeBeginRequest, reqId, b[:])
+}
+
+func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(typeEndRequest, reqId, b)
+}
+
+func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
+ w := newWriter(c, recType, reqId)
+ b := make([]byte, 8)
+ for k, v := range pairs {
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(v)))
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func readSize(s []byte) (uint32, int) {
+ if len(s) == 0 {
+ return 0, 0
+ }
+ size, n := uint32(s[0]), 1
+ if size&(1<<7) != 0 {
+ if len(s) < 4 {
+ return 0, 0
+ }
+ n = 4
+ size = binary.BigEndian.Uint32(s)
+ size &^= 1 << 31
+ }
+ return size, n
+}
+
+func readString(s []byte, size uint32) string {
+ if size > uint32(len(s)) {
+ return ""
+ }
+ return string(s[:size])
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
+ s := &streamWriter{c: c, recType: recType, reqId: reqId}
+ w, _ := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *conn
+ recType recType
+ reqId uint16
+}
+
+func (w *streamWriter) Write(p []byte) (int, error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *streamWriter) Close() error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, w.reqId, nil)
+}
diff --git a/src/pkg/net/http/fcgi/fcgi_test.go b/src/pkg/net/http/fcgi/fcgi_test.go
new file mode 100644
index 000000000..6c7e1a9ce
--- /dev/null
+++ b/src/pkg/net/http/fcgi/fcgi_test.go
@@ -0,0 +1,150 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+)
+
+var sizeTests = []struct {
+ size uint32
+ bytes []byte
+}{
+ {0, []byte{0x00}},
+ {127, []byte{0x7F}},
+ {128, []byte{0x80, 0x00, 0x00, 0x80}},
+ {1000, []byte{0x80, 0x00, 0x03, 0xE8}},
+ {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}},
+}
+
+func TestSize(t *testing.T) {
+ b := make([]byte, 4)
+ for i, test := range sizeTests {
+ n := encodeSize(b, test.size)
+ if !bytes.Equal(b[:n], test.bytes) {
+ t.Errorf("%d expected %x, encoded %x", i, test.bytes, b)
+ }
+ size, n := readSize(test.bytes)
+ if size != test.size {
+ t.Errorf("%d expected %d, read %d", i, test.size, size)
+ }
+ if len(test.bytes) != n {
+ t.Errorf("%d did not consume all the bytes", i)
+ }
+ }
+}
+
+var streamTests = []struct {
+ desc string
+ recType recType
+ reqId uint16
+ content []byte
+ raw []byte
+}{
+ {"single record", typeStdout, 1, nil,
+ []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
+ },
+ // this data will have to be split into two records
+ {"two records", typeStdin, 300, make([]byte, 66000),
+ bytes.Join([][]byte{
+ // header for the first record
+ {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
+ make([]byte, 65536),
+ // header for the second
+ {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
+ make([]byte, 472),
+ // header for the empty record
+ {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
+ },
+ nil),
+ },
+}
+
+type nilCloser struct {
+ io.ReadWriter
+}
+
+func (c *nilCloser) Close() error { return nil }
+
+func TestStreams(t *testing.T) {
+ var rec record
+outer:
+ for _, test := range streamTests {
+ buf := bytes.NewBuffer(test.raw)
+ var content []byte
+ for buf.Len() > 0 {
+ if err := rec.read(buf); err != nil {
+ t.Errorf("%s: error reading record: %v", test.desc, err)
+ continue outer
+ }
+ content = append(content, rec.content()...)
+ }
+ if rec.h.Type != test.recType {
+ t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType)
+ continue
+ }
+ if rec.h.Id != test.reqId {
+ t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId)
+ continue
+ }
+ if !bytes.Equal(content, test.content) {
+ t.Errorf("%s: read wrong content", test.desc)
+ continue
+ }
+ buf.Reset()
+ c := newConn(&nilCloser{buf})
+ w := newWriter(c, test.recType, test.reqId)
+ if _, err := w.Write(test.content); err != nil {
+ t.Errorf("%s: error writing record: %v", test.desc, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: error closing stream: %v", test.desc, err)
+ continue
+ }
+ if !bytes.Equal(buf.Bytes(), test.raw) {
+ t.Errorf("%s: wrote wrong content", test.desc)
+ }
+ }
+}
+
+type writeOnlyConn struct {
+ buf []byte
+}
+
+func (c *writeOnlyConn) Write(p []byte) (int, error) {
+ c.buf = append(c.buf, p...)
+ return len(p), nil
+}
+
+func (c *writeOnlyConn) Read(p []byte) (int, error) {
+ return 0, errors.New("conn is write-only")
+}
+
+func (c *writeOnlyConn) Close() error {
+ return nil
+}
+
+func TestGetValues(t *testing.T) {
+ var rec record
+ rec.h.Type = typeGetValues
+
+ wc := new(writeOnlyConn)
+ c := newChild(wc, nil)
+ err := c.handleRecord(&rec)
+ if err != nil {
+ t.Fatalf("handleRecord: %v", err)
+ }
+
+ const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
+ "\x0f\x01FCGI_MPXS_CONNS1" +
+ "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
+ if got := string(wc.buf); got != want {
+ t.Errorf(" got: %q\nwant: %q\n", got, want)
+ }
+}
diff --git a/src/pkg/net/http/filetransport.go b/src/pkg/net/http/filetransport.go
new file mode 100644
index 000000000..821787e0c
--- /dev/null
+++ b/src/pkg/net/http/filetransport.go
@@ -0,0 +1,123 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+)
+
+// fileTransport implements RoundTripper for the 'file' protocol.
+type fileTransport struct {
+ fh fileHandler
+}
+
+// NewFileTransport returns a new RoundTripper, serving the provided
+// FileSystem. The returned RoundTripper ignores the URL host in its
+// incoming requests, as well as most other properties of the
+// request.
+//
+// The typical use case for NewFileTransport is to register the "file"
+// protocol with a Transport, as in:
+//
+// t := &http.Transport{}
+// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
+// c := &http.Client{Transport: t}
+// res, err := c.Get("file:///etc/passwd")
+// ...
+func NewFileTransport(fs FileSystem) RoundTripper {
+ return fileTransport{fileHandler{fs}}
+}
+
+func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ // We start ServeHTTP in a goroutine, which may take a long
+ // time if the file is large. The newPopulateResponseWriter
+ // call returns a channel which either ServeHTTP or finish()
+ // sends our *Response on, once the *Response itself has been
+ // populated (even if the body itself is still being
+ // written to the res.Body, a pipe)
+ rw, resc := newPopulateResponseWriter()
+ go func() {
+ t.fh.ServeHTTP(rw, req)
+ rw.finish()
+ }()
+ return <-resc, nil
+}
+
+func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
+ pr, pw := io.Pipe()
+ rw := &populateResponse{
+ ch: make(chan *Response),
+ pw: pw,
+ res: &Response{
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ Header: make(Header),
+ Close: true,
+ Body: pr,
+ },
+ }
+ return rw, rw.ch
+}
+
+// populateResponse is a ResponseWriter that populates the *Response
+// in res, and writes its body to a pipe connected to the response
+// body. Once writes begin or finish() is called, the response is sent
+// on ch.
+type populateResponse struct {
+ res *Response
+ ch chan *Response
+ wroteHeader bool
+ hasContent bool
+ sentResponse bool
+ pw *io.PipeWriter
+}
+
+func (pr *populateResponse) finish() {
+ if !pr.wroteHeader {
+ pr.WriteHeader(500)
+ }
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ pr.pw.Close()
+}
+
+func (pr *populateResponse) sendResponse() {
+ if pr.sentResponse {
+ return
+ }
+ pr.sentResponse = true
+
+ if pr.hasContent {
+ pr.res.ContentLength = -1
+ }
+ pr.ch <- pr.res
+}
+
+func (pr *populateResponse) Header() Header {
+ return pr.res.Header
+}
+
+func (pr *populateResponse) WriteHeader(code int) {
+ if pr.wroteHeader {
+ return
+ }
+ pr.wroteHeader = true
+
+ pr.res.StatusCode = code
+ pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
+}
+
+func (pr *populateResponse) Write(p []byte) (n int, err error) {
+ if !pr.wroteHeader {
+ pr.WriteHeader(StatusOK)
+ }
+ pr.hasContent = true
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ return pr.pw.Write(p)
+}
diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go
new file mode 100644
index 000000000..039926b53
--- /dev/null
+++ b/src/pkg/net/http/filetransport_test.go
@@ -0,0 +1,65 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func checker(t *testing.T) func(string, error) {
+ return func(call string, err error) {
+ if err == nil {
+ return
+ }
+ t.Fatalf("%s: %v", call, err)
+ }
+}
+
+func TestFileTransport(t *testing.T) {
+ check := checker(t)
+
+ dname, err := ioutil.TempDir("", "")
+ check("TempDir", err)
+ fname := filepath.Join(dname, "foo.txt")
+ err = ioutil.WriteFile(fname, []byte("Bar"), 0644)
+ check("WriteFile", err)
+ defer os.Remove(dname)
+ defer os.Remove(fname)
+
+ tr := &http.Transport{}
+ tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname)))
+ c := &http.Client{Transport: tr}
+
+ fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
+ for _, urlstr := range fooURLs {
+ res, err := c.Get(urlstr)
+ check("Get "+urlstr, err)
+ if res.StatusCode != 200 {
+ t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode)
+ }
+ if res.ContentLength != -1 {
+ t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength)
+ }
+ if res.Body == nil {
+ t.Fatalf("for %s, nil Body", urlstr)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ check("ReadAll "+urlstr, err)
+ if string(slurp) != "Bar" {
+ t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar")
+ }
+ }
+
+ const badURL = "file://../no-exist.txt"
+ res, err := c.Get(badURL)
+ check("Get "+badURL, err)
+ if res.StatusCode != 404 {
+ t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
+ }
+}
diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go
new file mode 100644
index 000000000..1392ca68a
--- /dev/null
+++ b/src/pkg/net/http/fs.go
@@ -0,0 +1,330 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP file system request handler
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// A Dir implements http.FileSystem using the native file
+// system restricted to a specific directory tree.
+//
+// An empty Dir is treated as ".".
+type Dir string
+
+func (d Dir) Open(name string) (File, error) {
+ if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
+ return nil, errors.New("http: invalid character in file path")
+ }
+ dir := string(d)
+ if dir == "" {
+ dir = "."
+ }
+ f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+}
+
+// A FileSystem implements access to a collection of named files.
+// The elements in a file path are separated by slash ('/', U+002F)
+// characters, regardless of host operating system convention.
+type FileSystem interface {
+ Open(name string) (File, error)
+}
+
+// A File is returned by a FileSystem's Open method and can be
+// served by the FileServer implementation.
+type File interface {
+ Close() error
+ Stat() (os.FileInfo, error)
+ Readdir(count int) ([]os.FileInfo, error)
+ Read([]byte) (int, error)
+ Seek(offset int64, whence int) (int64, error)
+}
+
+// Heuristic: b is text if it is valid UTF-8 and doesn't
+// contain any unprintable ASCII or Unicode characters.
+func isText(b []byte) bool {
+ for len(b) > 0 && utf8.FullRune(b) {
+ rune, size := utf8.DecodeRune(b)
+ if size == 1 && rune == utf8.RuneError {
+ // decoding error
+ return false
+ }
+ if 0x7F <= rune && rune <= 0x9F {
+ return false
+ }
+ if rune < ' ' {
+ switch rune {
+ case '\n', '\r', '\t':
+ // okay
+ default:
+ // binary garbage
+ return false
+ }
+ }
+ b = b[size:]
+ }
+ return true
+}
+
+func dirList(w ResponseWriter, f File) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ fmt.Fprintf(w, "<pre>\n")
+ for {
+ dirs, err := f.Readdir(100)
+ if err != nil || len(dirs) == 0 {
+ break
+ }
+ for _, d := range dirs {
+ name := d.Name()
+ if d.IsDir() {
+ name += "/"
+ }
+ // TODO htmlescape
+ fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", name, name)
+ }
+ }
+ fmt.Fprintf(w, "</pre>\n")
+}
+
+// name is '/'-separated, not filepath.Separator.
+func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
+ const indexPage = "/index.html"
+
+ // redirect .../index.html to .../
+ // can't use Redirect() because that would make the path absolute,
+ // which would be a problem running under StripPrefix
+ if strings.HasSuffix(r.URL.Path, indexPage) {
+ localRedirect(w, r, "./")
+ return
+ }
+
+ f, err := fs.Open(name)
+ if err != nil {
+ // TODO expose actual error?
+ NotFound(w, r)
+ return
+ }
+ defer f.Close()
+
+ d, err1 := f.Stat()
+ if err1 != nil {
+ // TODO expose actual error?
+ NotFound(w, r)
+ return
+ }
+
+ if redirect {
+ // redirect to canonical path: / at end of directory url
+ // r.URL.Path always begins with /
+ url := r.URL.Path
+ if d.IsDir() {
+ if url[len(url)-1] != '/' {
+ localRedirect(w, r, path.Base(url)+"/")
+ return
+ }
+ } else {
+ if url[len(url)-1] == '/' {
+ localRedirect(w, r, "../"+path.Base(url))
+ return
+ }
+ }
+ }
+
+ if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && !d.ModTime().After(t) {
+ w.WriteHeader(StatusNotModified)
+ return
+ }
+ w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat))
+
+ // use contents of index.html for directory, if present
+ if d.IsDir() {
+ index := name + indexPage
+ ff, err := fs.Open(index)
+ if err == nil {
+ defer ff.Close()
+ dd, err := ff.Stat()
+ if err == nil {
+ name = index
+ d = dd
+ f = ff
+ }
+ }
+ }
+
+ if d.IsDir() {
+ dirList(w, f)
+ return
+ }
+
+ // serve file
+ size := d.Size()
+ code := StatusOK
+
+ // If Content-Type isn't set, use the file's extension to find it.
+ if w.Header().Get("Content-Type") == "" {
+ ctype := mime.TypeByExtension(filepath.Ext(name))
+ if ctype == "" {
+ // read a chunk to decide between utf-8 text and binary
+ var buf [1024]byte
+ n, _ := io.ReadFull(f, buf[:])
+ b := buf[:n]
+ if isText(b) {
+ ctype = "text/plain; charset=utf-8"
+ } else {
+ // generic binary
+ ctype = "application/octet-stream"
+ }
+ f.Seek(0, os.SEEK_SET) // rewind to output whole file
+ }
+ w.Header().Set("Content-Type", ctype)
+ }
+
+ // handle Content-Range header.
+ // TODO(adg): handle multiple ranges
+ ranges, err := parseRange(r.Header.Get("Range"), size)
+ if err == nil && len(ranges) > 1 {
+ err = errors.New("multiple ranges not supported")
+ }
+ if err != nil {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ if len(ranges) == 1 {
+ ra := ranges[0]
+ if _, err := f.Seek(ra.start, os.SEEK_SET); err != nil {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ size = ra.length
+ code = StatusPartialContent
+ w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size()))
+ }
+
+ w.Header().Set("Accept-Ranges", "bytes")
+ if w.Header().Get("Content-Encoding") == "" {
+ w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
+ }
+
+ w.WriteHeader(code)
+
+ if r.Method != "HEAD" {
+ io.CopyN(w, f, size)
+ }
+}
+
+// localRedirect gives a Moved Permanently response.
+// It does not convert relative paths to absolute paths like Redirect does.
+func localRedirect(w ResponseWriter, r *Request, newPath string) {
+ if q := r.URL.RawQuery; q != "" {
+ newPath += "?" + q
+ }
+ w.Header().Set("Location", newPath)
+ w.WriteHeader(StatusMovedPermanently)
+}
+
+// ServeFile replies to the request with the contents of the named file or directory.
+func ServeFile(w ResponseWriter, r *Request, name string) {
+ dir, file := filepath.Split(name)
+ serveFile(w, r, Dir(dir), file, false)
+}
+
+type fileHandler struct {
+ root FileSystem
+}
+
+// FileServer returns a handler that serves HTTP requests
+// with the contents of the file system rooted at root.
+//
+// To use the operating system's file system implementation,
+// use http.Dir:
+//
+// http.Handle("/", http.FileServer(http.Dir("/tmp")))
+func FileServer(root FileSystem) Handler {
+ return &fileHandler{root}
+}
+
+func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ upath := r.URL.Path
+ if !strings.HasPrefix(upath, "/") {
+ upath = "/" + upath
+ r.URL.Path = upath
+ }
+ serveFile(w, r, f.root, path.Clean(upath), true)
+}
+
+// httpRange specifies the byte range to be sent to the client.
+type httpRange struct {
+ start, length int64
+}
+
+// parseRange parses a Range header string as per RFC 2616.
+func parseRange(s string, size int64) ([]httpRange, error) {
+ if s == "" {
+ return nil, nil // header not present
+ }
+ const b = "bytes="
+ if !strings.HasPrefix(s, b) {
+ return nil, errors.New("invalid range")
+ }
+ var ranges []httpRange
+ for _, ra := range strings.Split(s[len(b):], ",") {
+ i := strings.Index(ra, "-")
+ if i < 0 {
+ return nil, errors.New("invalid range")
+ }
+ start, end := ra[:i], ra[i+1:]
+ var r httpRange
+ if start == "" {
+ // If no start is specified, end specifies the
+ // range start relative to the end of the file.
+ i, err := strconv.ParseInt(end, 10, 64)
+ if err != nil {
+ return nil, errors.New("invalid range")
+ }
+ if i > size {
+ i = size
+ }
+ r.start = size - i
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(start, 10, 64)
+ if err != nil || i > size || i < 0 {
+ return nil, errors.New("invalid range")
+ }
+ r.start = i
+ if end == "" {
+ // If no end is specified, range extends to end of the file.
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(end, 10, 64)
+ if err != nil || r.start > i {
+ return nil, errors.New("invalid range")
+ }
+ if i >= size {
+ i = size - 1
+ }
+ r.length = i - r.start + 1
+ }
+ }
+ ranges = append(ranges, r)
+ }
+ return ranges, nil
+}
diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go
new file mode 100644
index 000000000..85cad3ec7
--- /dev/null
+++ b/src/pkg/net/http/fs_test.go
@@ -0,0 +1,334 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+const (
+ testFile = "testdata/file"
+ testFileLength = 11
+)
+
+var ServeFileRangeTests = []struct {
+ start, end int
+ r string
+ code int
+}{
+ {0, testFileLength, "", StatusOK},
+ {0, 5, "0-4", StatusPartialContent},
+ {2, testFileLength, "2-", StatusPartialContent},
+ {testFileLength - 5, testFileLength, "-5", StatusPartialContent},
+ {3, 8, "3-7", StatusPartialContent},
+ {0, 0, "20-", StatusRequestedRangeNotSatisfiable},
+}
+
+func TestServeFile(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+
+ var err error
+
+ file, err := ioutil.ReadFile(testFile)
+ if err != nil {
+ t.Fatal("reading file:", err)
+ }
+
+ // set up the Request (re-used for all tests)
+ var req Request
+ req.Header = make(Header)
+ if req.URL, err = url.Parse(ts.URL); err != nil {
+ t.Fatal("ParseURL:", err)
+ }
+ req.Method = "GET"
+
+ // straight GET
+ _, body := getBody(t, req)
+ if !equal(body, file) {
+ t.Fatalf("body mismatch: got %q, want %q", body, file)
+ }
+
+ // Range tests
+ for _, rt := range ServeFileRangeTests {
+ req.Header.Set("Range", "bytes="+rt.r)
+ if rt.r == "" {
+ req.Header["Range"] = nil
+ }
+ r, body := getBody(t, req)
+ if r.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code)
+ }
+ if rt.code == StatusRequestedRangeNotSatisfiable {
+ continue
+ }
+ h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength)
+ if rt.r == "" {
+ h = ""
+ }
+ cr := r.Header.Get("Content-Range")
+ if cr != h {
+ t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
+ }
+ if !equal(body, file[rt.start:rt.end]) {
+ t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
+ }
+ }
+}
+
+var fsRedirectTestData = []struct {
+ original, redirect string
+}{
+ {"/test/index.html", "/test/"},
+ {"/test/testdata", "/test/testdata/"},
+ {"/test/testdata/file/", "/test/testdata/file"},
+}
+
+func TestFSRedirect(t *testing.T) {
+ ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir("."))))
+ defer ts.Close()
+
+ for _, data := range fsRedirectTestData {
+ res, err := Get(ts.URL + data.original)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, e := res.Request.URL.Path, data.redirect; g != e {
+ t.Errorf("redirect from %s: got %s, want %s", data.original, g, e)
+ }
+ }
+}
+
+type testFileSystem struct {
+ open func(name string) (File, error)
+}
+
+func (fs *testFileSystem) Open(name string) (File, error) {
+ return fs.open(name)
+}
+
+func TestFileServerCleans(t *testing.T) {
+ ch := make(chan string, 1)
+ fs := FileServer(&testFileSystem{func(name string) (File, error) {
+ ch <- name
+ return nil, os.ENOENT
+ }})
+ tests := []struct {
+ reqPath, openArg string
+ }{
+ {"/foo.txt", "/foo.txt"},
+ {"//foo.txt", "/foo.txt"},
+ {"/../foo.txt", "/foo.txt"},
+ }
+ req, _ := NewRequest("GET", "http://example.com", nil)
+ for n, test := range tests {
+ rec := httptest.NewRecorder()
+ req.URL.Path = test.reqPath
+ fs.ServeHTTP(rec, req)
+ if got := <-ch; got != test.openArg {
+ t.Errorf("test %d: got %q, want %q", n, got, test.openArg)
+ }
+ }
+}
+
+func TestFileServerImplicitLeadingSlash(t *testing.T) {
+ tempDir, err := ioutil.TempDir("", "")
+ if err != nil {
+ t.Fatalf("TempDir: %v", err)
+ }
+ defer os.RemoveAll(tempDir)
+ if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+ ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir))))
+ defer ts.Close()
+ get := func(suffix string) string {
+ res, err := Get(ts.URL + suffix)
+ if err != nil {
+ t.Fatalf("Get %s: %v", suffix, err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll %s: %v", suffix, err)
+ }
+ return string(b)
+ }
+ if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") {
+ t.Logf("expected a directory listing with foo.txt, got %q", s)
+ }
+ if s := get("/bar/foo.txt"); s != "Hello world" {
+ t.Logf("expected %q, got %q", "Hello world", s)
+ }
+}
+
+func TestDirJoin(t *testing.T) {
+ wfi, err := os.Stat("/etc/hosts")
+ if err != nil {
+ t.Logf("skipping test; no /etc/hosts file")
+ return
+ }
+ test := func(d Dir, name string) {
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ gfi, err := f.Stat()
+ if err != nil {
+ t.Fatalf("stat of %s: %v", name, err)
+ }
+ if !gfi.(*os.FileStat).SameFile(wfi.(*os.FileStat)) {
+ t.Errorf("%s got different file", name)
+ }
+ }
+ test(Dir("/etc/"), "/hosts")
+ test(Dir("/etc/"), "hosts")
+ test(Dir("/etc/"), "../../../../hosts")
+ test(Dir("/etc"), "/hosts")
+ test(Dir("/etc"), "hosts")
+ test(Dir("/etc"), "../../../../hosts")
+
+ // Not really directories, but since we use this trick in
+ // ServeFile, test it:
+ test(Dir("/etc/hosts"), "")
+ test(Dir("/etc/hosts"), "/")
+ test(Dir("/etc/hosts"), "../")
+}
+
+func TestEmptyDirOpenCWD(t *testing.T) {
+ test := func(d Dir) {
+ name := "fs_test.go"
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ }
+ test(Dir(""))
+ test(Dir("."))
+ test(Dir("./"))
+}
+
+func TestServeFileContentType(t *testing.T) {
+ const ctype = "icecream/chocolate"
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.FormValue("override") == "1" {
+ w.Header().Set("Content-Type", ctype)
+ }
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+ get := func(override, want string) {
+ resp, err := Get(ts.URL + "?override=" + override)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if h := resp.Header.Get("Content-Type"); h != want {
+ t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
+ }
+ }
+ get("0", "text/plain; charset=utf-8")
+ get("1", ctype)
+}
+
+func TestServeFileMimeType(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/style.css")
+ }))
+ defer ts.Close()
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "text/css; charset=utf-8"
+ if h := resp.Header.Get("Content-Type"); h != want {
+ t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
+ }
+}
+
+func TestServeFileFromCWD(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "fs_test.go")
+ }))
+ defer ts.Close()
+ r, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if r.StatusCode != 200 {
+ t.Fatalf("expected 200 OK, got %s", r.Status)
+ }
+}
+
+func TestServeFileWithContentEncoding(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "foo")
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := resp.ContentLength, int64(-1); g != e {
+ t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
+ }
+}
+
+func TestServeIndexHtml(t *testing.T) {
+ const want = "index.html says hello\n"
+ ts := httptest.NewServer(FileServer(Dir(".")))
+ defer ts.Close()
+
+ for _, path := range []string{"/testdata/", "/testdata/index.html"} {
+ res, err := Get(ts.URL + path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != want {
+ t.Errorf("for path %q got %q, want %q", path, s, want)
+ }
+ }
+}
+
+func getBody(t *testing.T, req Request) (*Response, []byte) {
+ r, err := DefaultClient.Do(&req)
+ if err != nil {
+ t.Fatal(req.URL.String(), "send:", err)
+ }
+ b, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ return r, b
+}
+
+func equal(a, b []byte) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go
new file mode 100644
index 000000000..b107c312d
--- /dev/null
+++ b/src/pkg/net/http/header.go
@@ -0,0 +1,78 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+ "net/textproto"
+ "sort"
+ "strings"
+)
+
+// A Header represents the key-value pairs in an HTTP header.
+type Header map[string][]string
+
+// Add adds the key, value pair to the header.
+// It appends to any existing values associated with key.
+func (h Header) Add(key, value string) {
+ textproto.MIMEHeader(h).Add(key, value)
+}
+
+// Set sets the header entries associated with key to
+// the single element value. It replaces any existing
+// values associated with key.
+func (h Header) Set(key, value string) {
+ textproto.MIMEHeader(h).Set(key, value)
+}
+
+// Get gets the first value associated with the given key.
+// If there are no values associated with the key, Get returns "".
+// To access multiple values of a key, access the map directly
+// with CanonicalHeaderKey.
+func (h Header) Get(key string) string {
+ return textproto.MIMEHeader(h).Get(key)
+}
+
+// Del deletes the values associated with key.
+func (h Header) Del(key string) {
+ textproto.MIMEHeader(h).Del(key)
+}
+
+// Write writes a header in wire format.
+func (h Header) Write(w io.Writer) error {
+ return h.WriteSubset(w, nil)
+}
+
+var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
+
+// WriteSubset writes a header in wire format.
+// If exclude is not nil, keys where exclude[key] == true are not written.
+func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
+ keys := make([]string, 0, len(h))
+ for k := range h {
+ if exclude == nil || !exclude[k] {
+ keys = append(keys, k)
+ }
+ }
+ sort.Strings(keys)
+ for _, k := range keys {
+ for _, v := range h[k] {
+ v = headerNewlineToSpace.Replace(v)
+ v = strings.TrimSpace(v)
+ if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// CanonicalHeaderKey returns the canonical format of the
+// header key s. The canonicalization converts the first
+// letter and any letter following a hyphen to upper case;
+// the rest are converted to lowercase. For example, the
+// canonical key for "accept-encoding" is "Accept-Encoding".
+func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go
new file mode 100644
index 000000000..ccdee8a97
--- /dev/null
+++ b/src/pkg/net/http/header_test.go
@@ -0,0 +1,81 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "testing"
+)
+
+var headerWriteTests = []struct {
+ h Header
+ exclude map[string]bool
+ expected string
+}{
+ {Header{}, nil, ""},
+ {
+ Header{
+ "Content-Type": {"text/html; charset=UTF-8"},
+ "Content-Length": {"0"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n",
+ },
+ {
+ Header{
+ "Content-Length": {"0", "1", "2"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0", "1", "2"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true},
+ "",
+ },
+ {
+ Header{
+ "Nil": nil,
+ "Empty": {},
+ "Blank": {""},
+ "Double-Blank": {"", ""},
+ },
+ nil,
+ "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n",
+ },
+}
+
+func TestHeaderWrite(t *testing.T) {
+ var buf bytes.Buffer
+ for i, test := range headerWriteTests {
+ test.h.WriteSubset(&buf, test.exclude)
+ if buf.String() != test.expected {
+ t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected)
+ }
+ buf.Reset()
+ }
+}
diff --git a/src/pkg/net/http/httptest/Makefile b/src/pkg/net/http/httptest/Makefile
new file mode 100644
index 000000000..3bb445419
--- /dev/null
+++ b/src/pkg/net/http/httptest/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/http/httptest
+GOFILES=\
+ recorder.go\
+ server.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go
new file mode 100644
index 000000000..9aa0d510b
--- /dev/null
+++ b/src/pkg/net/http/httptest/recorder.go
@@ -0,0 +1,58 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptest provides utilities for HTTP testing.
+package httptest
+
+import (
+ "bytes"
+ "net/http"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+ Code int // the HTTP response code from WriteHeader
+ HeaderMap http.Header // the HTTP response headers
+ Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
+ Flushed bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+ return &ResponseRecorder{
+ HeaderMap: make(http.Header),
+ Body: new(bytes.Buffer),
+ }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header returns the response headers.
+func (rw *ResponseRecorder) Header() http.Header {
+ return rw.HeaderMap
+}
+
+// Write always succeeds and writes to rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ if rw.Body != nil {
+ rw.Body.Write(buf)
+ }
+ if rw.Code == 0 {
+ rw.Code = http.StatusOK
+ }
+ return len(buf), nil
+}
+
+// WriteHeader sets rw.Code.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+ rw.Code = code
+}
+
+// Flush sets rw.Flushed to true.
+func (rw *ResponseRecorder) Flush() {
+ rw.Flushed = true
+}
diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go
new file mode 100644
index 000000000..5b02e143d
--- /dev/null
+++ b/src/pkg/net/http/httptest/server.go
@@ -0,0 +1,173 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+ URL string // base URL of form http://ipaddr:port with no trailing slash
+ Listener net.Listener
+ TLS *tls.Config // nil if not using using TLS
+
+ // Config may be changed after calling NewUnstartedServer and
+ // before Start or StartTLS.
+ Config *http.Server
+}
+
+// historyListener keeps track of all connections that it's ever
+// accepted.
+type historyListener struct {
+ net.Listener
+ history []net.Conn
+}
+
+func (hs *historyListener) Accept() (c net.Conn, err error) {
+ c, err = hs.Listener.Accept()
+ if err == nil {
+ hs.history = append(hs.history, c)
+ }
+ return
+}
+
+func newLocalListener() net.Listener {
+ if *serve != "" {
+ l, err := net.Listen("tcp", *serve)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
+ }
+ return l
+ }
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+ }
+ }
+ return l
+}
+
+// When debugging a particular http server-based test,
+// this flag lets you run
+// gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// to start the broken server so you can interact with it manually.
+var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.Start()
+ return ts
+}
+
+// NewUnstartedServer returns a new Server but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedServer(handler http.Handler) *Server {
+ return &Server{
+ Listener: newLocalListener(),
+ Config: &http.Server{Handler: handler},
+ }
+}
+
+// Start starts a server from NewUnstartedServer.
+func (s *Server) Start() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
+ s.URL = "http://" + s.Listener.Addr().String()
+ go s.Config.Serve(s.Listener)
+ if *serve != "" {
+ fmt.Println(os.Stderr, "httptest: serving on", s.URL)
+ select {}
+ }
+}
+
+// StartTLS starts TLS on a server from NewUnstartedServer.
+func (s *Server) StartTLS() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+
+ s.TLS = &tls.Config{
+ NextProtos: []string{"http/1.1"},
+ Certificates: []tls.Certificate{cert},
+ }
+ tlsListener := tls.NewListener(s.Listener, s.TLS)
+
+ s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
+ s.URL = "https://" + s.Listener.Addr().String()
+ go s.Config.Serve(s.Listener)
+}
+
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.StartTLS()
+ return ts
+}
+
+// Close shuts down the server.
+func (s *Server) Close() {
+ s.Listener.Close()
+}
+
+// CloseClientConnections closes any currently open HTTP connections
+// to the test Server.
+func (s *Server) CloseClientConnections() {
+ hl, ok := s.Listener.(*historyListener)
+ if !ok {
+ return
+ }
+ for _, conn := range hl.history {
+ conn.Close()
+ }
+}
+
+// localhostCert is a PEM-encoded TLS cert with SAN DNS names
+// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
+// of ASN.1 time).
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
+DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
+qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
+8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
+DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
+CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
++l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
+-----END CERTIFICATE-----
+`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v
+tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi
+SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0
+3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ
+/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN
+poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh
+AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93
+-----END RSA PRIVATE KEY-----
+`)
diff --git a/src/pkg/net/http/httputil/Makefile b/src/pkg/net/http/httputil/Makefile
new file mode 100644
index 000000000..8bfc7a022
--- /dev/null
+++ b/src/pkg/net/http/httputil/Makefile
@@ -0,0 +1,14 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/http/httputil
+GOFILES=\
+ chunked.go\
+ dump.go\
+ persist.go\
+ reverseproxy.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go
new file mode 100644
index 000000000..29eaf3475
--- /dev/null
+++ b/src/pkg/net/http/httputil/chunked.go
@@ -0,0 +1,172 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The wire protocol for HTTP's "chunked" Transfer-Encoding.
+
+// This code is a duplicate of ../chunked.go with these edits:
+// s/newChunked/NewChunked/g
+// s/package http/package httputil/
+// Please make any changes in both files.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "io"
+ "strconv"
+)
+
+const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
+
+var ErrLineTooLong = errors.New("header line too long")
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ br, ok := r.(*bufio.Reader)
+ if !ok {
+ br = bufio.NewReader(r)
+ }
+ return &chunkedReader{r: br}
+}
+
+type chunkedReader struct {
+ r *bufio.Reader
+ n uint64 // unread bytes in chunk
+ err error
+}
+
+func (cr *chunkedReader) beginChunk() {
+ // chunk-size CRLF
+ var line string
+ line, cr.err = readLine(cr.r)
+ if cr.err != nil {
+ return
+ }
+ cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ if cr.err != nil {
+ return
+ }
+ if cr.n == 0 {
+ cr.err = io.EOF
+ }
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ if cr.err != nil {
+ return 0, cr.err
+ }
+ if cr.n == 0 {
+ cr.beginChunk()
+ if cr.err != nil {
+ return 0, cr.err
+ }
+ }
+ if uint64(len(b)) > cr.n {
+ b = b[0:cr.n]
+ }
+ n, cr.err = cr.r.Read(b)
+ cr.n -= uint64(n)
+ if cr.n == 0 && cr.err == nil {
+ // end of chunk (CRLF)
+ b := make([]byte, 2)
+ if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
+ if b[0] != '\r' || b[1] != '\n' {
+ cr.err = errors.New("malformed chunked encoding")
+ }
+ }
+ }
+ return n, cr.err
+}
+
+// Read a line of bytes (up to \n) from b.
+// Give up if the line exceeds maxLineLength.
+// The returned bytes are a pointer into storage in
+// the bufio, so they are only valid until the next bufio read.
+func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+ if p, err = b.ReadSlice('\n'); err != nil {
+ // We always know when EOF is coming.
+ // If the caller asked for a line, there should be a line.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ } else if err == bufio.ErrBufferFull {
+ err = ErrLineTooLong
+ }
+ return nil, err
+ }
+ if len(p) >= maxLineLength {
+ return nil, ErrLineTooLong
+ }
+
+ // Chop off trailing white space.
+ p = bytes.TrimRight(p, " \r\t\n")
+
+ return p, nil
+}
+
+// readLineBytes, but convert the bytes into a string.
+func readLine(b *bufio.Reader) (s string, err error) {
+ p, e := readLineBytes(b)
+ if e != nil {
+ return "", e
+ }
+ return string(p), nil
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using NewChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return &chunkedWriter{w}
+}
+
+// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
+// Encoding wire format to the underlying Wire chunkedWriter.
+type chunkedWriter struct {
+ Wire io.Writer
+}
+
+// Write the contents of data as one chunk to Wire.
+// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has
+// a bug since it does not check for success of io.WriteString
+func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
+
+ // Don't send 0-length data. It looks like EOF for chunked encoding.
+ if len(data) == 0 {
+ return 0, nil
+ }
+
+ head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
+
+ if _, err = io.WriteString(cw.Wire, head); err != nil {
+ return 0, err
+ }
+ if n, err = cw.Wire.Write(data); err != nil {
+ return
+ }
+ if n != len(data) {
+ err = io.ErrShortWrite
+ return
+ }
+ _, err = io.WriteString(cw.Wire, "\r\n")
+
+ return
+}
+
+func (cw *chunkedWriter) Close() error {
+ _, err := io.WriteString(cw.Wire, "0\r\n")
+ return err
+}
diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go
new file mode 100644
index 000000000..155a32bdf
--- /dev/null
+++ b/src/pkg/net/http/httputil/chunked_test.go
@@ -0,0 +1,41 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This code is a duplicate of ../chunked_test.go with these edits:
+// s/newChunked/NewChunked/g
+// s/package http/package httputil/
+// Please make any changes in both files.
+
+package httputil
+
+import (
+ "bytes"
+ "io/ioutil"
+ "testing"
+)
+
+func TestChunk(t *testing.T) {
+ var b bytes.Buffer
+
+ w := NewChunkedWriter(&b)
+ const chunk1 = "hello, "
+ const chunk2 = "world! 0123456789abcdef"
+ w.Write([]byte(chunk1))
+ w.Write([]byte(chunk2))
+ w.Close()
+
+ if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
+ t.Fatalf("chunk writer wrote %q; want %q", g, e)
+ }
+
+ r := NewChunkedReader(&b)
+ data, err := ioutil.ReadAll(r)
+ if err != nil {
+ t.Logf(`data: "%s"`, data)
+ t.Fatalf("ReadAll from reader: %v", err)
+ }
+ if g, e := string(data), chunk1+chunk2; g != e {
+ t.Errorf("chunk reader read %q; want %q", g, e)
+ }
+}
diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go
new file mode 100644
index 000000000..b8a98ee42
--- /dev/null
+++ b/src/pkg/net/http/httputil/dump.go
@@ -0,0 +1,196 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+)
+
+// One of the copies, say from b to r2, could be avoided by using a more
+// elaborate trick where the other copy is made during Request/Response.Write.
+// This would complicate things too much, given that these functions are for
+// debugging only.
+func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
+ var buf bytes.Buffer
+ if _, err = buf.ReadFrom(b); err != nil {
+ return nil, nil, err
+ }
+ if err = b.Close(); err != nil {
+ return nil, nil, err
+ }
+ return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil
+}
+
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+// DumpRequestOut is like DumpRequest but includes
+// headers that the standard http.Transport adds,
+// such as User-Agent.
+func DumpRequestOut(req *http.Request, body bool) (dump []byte, err error) {
+ save := req.Body
+ if !body || req.Body == nil {
+ req.Body = nil
+ } else {
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return
+ }
+ }
+
+ var b bytes.Buffer
+ dialed := false
+ t := &http.Transport{
+ Dial: func(net, addr string) (c net.Conn, err error) {
+ if dialed {
+ return nil, errors.New("unexpected second dial")
+ }
+ c = &dumpConn{
+ Writer: &b,
+ Reader: strings.NewReader("HTTP/1.1 500 Fake Error\r\n\r\n"),
+ }
+ return
+ },
+ }
+
+ _, err = t.RoundTrip(req)
+
+ req.Body = save
+ if err != nil {
+ return
+ }
+ dump = b.Bytes()
+ return
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+var reqWriteExcludeHeaderDump = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// dumpAsReceived writes req to w in the form as it was received, or
+// at least as accurately as possible from the information retained in
+// the request.
+func dumpAsReceived(req *http.Request, w io.Writer) error {
+ return nil
+}
+
+// DumpRequest returns the as-received wire representation of req,
+// optionally including the request body, for debugging.
+// DumpRequest is semantically a no-op, but in order to
+// dump the body, it reads the body data into memory and
+// changes req.Body to refer to the in-memory copy.
+// The documentation for http.Request.Write details which fields
+// of req are used.
+func DumpRequest(req *http.Request, body bool) (dump []byte, err error) {
+ save := req.Body
+ if !body || req.Body == nil {
+ req.Body = nil
+ } else {
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return
+ }
+ }
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
+ req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor)
+
+ host := req.Host
+ if host == "" && req.URL != nil {
+ host = req.URL.Host
+ }
+ if host != "" {
+ fmt.Fprintf(&b, "Host: %s\r\n", host)
+ }
+
+ chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
+ if len(req.TransferEncoding) > 0 {
+ fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
+ }
+ if req.Close {
+ fmt.Fprintf(&b, "Connection: close\r\n")
+ }
+
+ err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
+ if err != nil {
+ return
+ }
+
+ io.WriteString(&b, "\r\n")
+
+ if req.Body != nil {
+ var dest io.Writer = &b
+ if chunked {
+ dest = NewChunkedWriter(dest)
+ }
+ _, err = io.Copy(dest, req.Body)
+ if chunked {
+ dest.(io.Closer).Close()
+ io.WriteString(&b, "\r\n")
+ }
+ }
+
+ req.Body = save
+ if err != nil {
+ return
+ }
+ dump = b.Bytes()
+ return
+}
+
+// DumpResponse is like DumpRequest but dumps a response.
+func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) {
+ var b bytes.Buffer
+ save := resp.Body
+ savecl := resp.ContentLength
+ if !body || resp.Body == nil {
+ resp.Body = nil
+ resp.ContentLength = 0
+ } else {
+ save, resp.Body, err = drainBody(resp.Body)
+ if err != nil {
+ return
+ }
+ }
+ err = resp.Write(&b)
+ resp.Body = save
+ resp.ContentLength = savecl
+ if err != nil {
+ return
+ }
+ dump = b.Bytes()
+ return
+}
diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go
new file mode 100644
index 000000000..819efb584
--- /dev/null
+++ b/src/pkg/net/http/httputil/dump_test.go
@@ -0,0 +1,140 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "testing"
+)
+
+type dumpTest struct {
+ Req http.Request
+ Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ WantDump string
+ WantDumpOut string
+}
+
+var dumpTests = []dumpTest{
+
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDump: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: http.Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: http.Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantDump: "GET /foo HTTP/1.0\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ {
+ Req: *mustNewRequest("GET", "http://example.com/foo", nil),
+
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+}
+
+func TestDumpRequest(t *testing.T) {
+ for i, tt := range dumpTests {
+ setBody := func() {
+ if tt.Body == nil {
+ return
+ }
+ switch b := tt.Body.(type) {
+ case []byte:
+ tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
+ case func() io.ReadCloser:
+ tt.Req.Body = b()
+ }
+ }
+ setBody()
+ if tt.Req.Header == nil {
+ tt.Req.Header = make(http.Header)
+ }
+
+ if tt.WantDump != "" {
+ setBody()
+ dump, err := DumpRequest(&tt.Req, true)
+ if err != nil {
+ t.Errorf("DumpRequest #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDump {
+ t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump))
+ continue
+ }
+ }
+
+ if tt.WantDumpOut != "" {
+ setBody()
+ dump, err := DumpRequestOut(&tt.Req, true)
+ if err != nil {
+ t.Errorf("DumpRequestOut #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDumpOut {
+ t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump))
+ continue
+ }
+ }
+ }
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+func mustNewRequest(method, url string, body io.Reader) *http.Request {
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err))
+ }
+ return req
+}
diff --git a/src/pkg/net/http/httputil/persist.go b/src/pkg/net/http/httputil/persist.go
new file mode 100644
index 000000000..1266bd3ad
--- /dev/null
+++ b/src/pkg/net/http/httputil/persist.go
@@ -0,0 +1,422 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httputil provides HTTP utility functions, complementing the
+// more common ones in the net/http package.
+package httputil
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/textproto"
+ "os"
+ "sync"
+)
+
+var (
+ ErrPersistEOF = &http.ProtocolError{"persistent connection closed"}
+ ErrPipeline = &http.ProtocolError{"pipeline error"}
+)
+
+// This is an API usage error - the local side is closed.
+// ErrPersistEOF (above) reports that the remote side is closed.
+var errClosed = errors.New("i/o operation on closed connection")
+
+// A ServerConn reads requests and sends responses over an underlying
+// connection, until the HTTP keepalive logic commands an end. ServerConn
+// also allows hijacking the underlying connection by calling Hijack
+// to regain control over the connection. ServerConn supports pipe-lining,
+// i.e. requests can be read out of sync (but in the same order) while the
+// respective responses are sent.
+//
+// ServerConn is low-level and should not be needed by most applications.
+// See Server.
+type ServerConn struct {
+ lk sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+}
+
+// NewServerConn returns a new ServerConn reading and writing c. If r is not
+// nil, it is the buffer to use when reading c.
+func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
+}
+
+// Hijack detaches the ServerConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before Read has signaled the end of the keep-alive logic. The user
+// should not call Hijack while Read or Write is in progress.
+func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ c = sc.c
+ r = sc.r
+ sc.c = nil
+ sc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection
+func (sc *ServerConn) Close() error {
+ c, _ := sc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Read returns the next request on the wire. An ErrPersistEOF is returned if
+// it is gracefully determined that there are no more requests (e.g. after the
+// first request on an HTTP/1.0 connection, or after a Connection:close on a
+// HTTP/1.1 connection).
+func (sc *ServerConn) Read() (req *http.Request, err error) {
+
+ // Ensure ordered execution of Reads and Writes
+ id := sc.pipe.Next()
+ sc.pipe.StartRequest(id)
+ defer func() {
+ sc.pipe.EndRequest(id)
+ if req == nil {
+ sc.pipe.StartResponse(id)
+ sc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ sc.lk.Lock()
+ sc.pipereq[req] = id
+ sc.lk.Unlock()
+ }
+ }()
+
+ sc.lk.Lock()
+ if sc.we != nil { // no point receiving if write-side broken or closed
+ defer sc.lk.Unlock()
+ return nil, sc.we
+ }
+ if sc.re != nil {
+ defer sc.lk.Unlock()
+ return nil, sc.re
+ }
+ if sc.r == nil { // connection closed by user in the meantime
+ defer sc.lk.Unlock()
+ return nil, errClosed
+ }
+ r := sc.r
+ lastbody := sc.lastbody
+ sc.lastbody = nil
+ sc.lk.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ sc.re = err
+ return nil, err
+ }
+ }
+
+ req, err = http.ReadRequest(r)
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ if err != nil {
+ if err == io.ErrUnexpectedEOF {
+ // A close from the opposing client is treated as a
+ // graceful close, even if there was some unparse-able
+ // data before the close.
+ sc.re = ErrPersistEOF
+ return nil, sc.re
+ } else {
+ sc.re = err
+ return req, err
+ }
+ }
+ sc.lastbody = req.Body
+ sc.nread++
+ if req.Close {
+ sc.re = ErrPersistEOF
+ return req, sc.re
+ }
+ return req, err
+}
+
+// Pending returns the number of unanswered requests
+// that have been received on the connection.
+func (sc *ServerConn) Pending() int {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ return sc.nread - sc.nwritten
+}
+
+// Write writes resp in response to req. To close the connection gracefully, set the
+// Response.Close field to true. Write should be considered operational until
+// it returns an error, regardless of any errors returned on the Read side.
+func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
+
+ // Retrieve the pipeline ID of this request/response pair
+ sc.lk.Lock()
+ id, ok := sc.pipereq[req]
+ delete(sc.pipereq, req)
+ if !ok {
+ sc.lk.Unlock()
+ return ErrPipeline
+ }
+ sc.lk.Unlock()
+
+ // Ensure pipeline order
+ sc.pipe.StartResponse(id)
+ defer sc.pipe.EndResponse(id)
+
+ sc.lk.Lock()
+ if sc.we != nil {
+ defer sc.lk.Unlock()
+ return sc.we
+ }
+ if sc.c == nil { // connection closed by user in the meantime
+ defer sc.lk.Unlock()
+ return os.EBADF
+ }
+ c := sc.c
+ if sc.nread <= sc.nwritten {
+ defer sc.lk.Unlock()
+ return errors.New("persist server pipe count")
+ }
+ if resp.Close {
+ // After signaling a keep-alive close, any pipelined unread
+ // requests will be lost. It is up to the user to drain them
+ // before signaling.
+ sc.re = ErrPersistEOF
+ }
+ sc.lk.Unlock()
+
+ err := resp.Write(c)
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ if err != nil {
+ sc.we = err
+ return err
+ }
+ sc.nwritten++
+
+ return nil
+}
+
+// A ClientConn sends request and receives headers over an underlying
+// connection, while respecting the HTTP keepalive logic. ClientConn
+// supports hijacking the connection calling Hijack to
+// regain control of the underlying net.Conn and deal with it as desired.
+//
+// ClientConn is low-level and should not be needed by most applications.
+// See Client.
+type ClientConn struct {
+ lk sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+ writeReq func(*http.Request, io.Writer) error
+}
+
+// NewClientConn returns a new ClientConn reading and writing c. If r is not
+// nil, it is the buffer to use when reading c.
+func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ClientConn{
+ c: c,
+ r: r,
+ pipereq: make(map[*http.Request]uint),
+ writeReq: (*http.Request).Write,
+ }
+}
+
+// NewProxyClientConn works like NewClientConn but writes Requests
+// using Request's WriteProxy method.
+func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ cc := NewClientConn(c, r)
+ cc.writeReq = (*http.Request).WriteProxy
+ return cc
+}
+
+// Hijack detaches the ClientConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before the user or Read have signaled the end of the keep-alive
+// logic. The user should not call Hijack while Read or Write is in progress.
+func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ c = cc.c
+ r = cc.r
+ cc.c = nil
+ cc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection
+func (cc *ClientConn) Close() error {
+ c, _ := cc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Write writes a request. An ErrPersistEOF error is returned if the connection
+// has been closed in an HTTP keepalive sense. If req.Close equals true, the
+// keepalive connection is logically closed after this request and the opposing
+// server is informed. An ErrUnexpectedEOF indicates the remote closed the
+// underlying TCP connection, which is usually considered as graceful close.
+func (cc *ClientConn) Write(req *http.Request) (err error) {
+
+ // Ensure ordered execution of Writes
+ id := cc.pipe.Next()
+ cc.pipe.StartRequest(id)
+ defer func() {
+ cc.pipe.EndRequest(id)
+ if err != nil {
+ cc.pipe.StartResponse(id)
+ cc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ cc.lk.Lock()
+ cc.pipereq[req] = id
+ cc.lk.Unlock()
+ }
+ }()
+
+ cc.lk.Lock()
+ if cc.re != nil { // no point sending if read-side closed or broken
+ defer cc.lk.Unlock()
+ return cc.re
+ }
+ if cc.we != nil {
+ defer cc.lk.Unlock()
+ return cc.we
+ }
+ if cc.c == nil { // connection closed by user in the meantime
+ defer cc.lk.Unlock()
+ return errClosed
+ }
+ c := cc.c
+ if req.Close {
+ // We write the EOF to the write-side error, because there
+ // still might be some pipelined reads
+ cc.we = ErrPersistEOF
+ }
+ cc.lk.Unlock()
+
+ err = cc.writeReq(req, c)
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ if err != nil {
+ cc.we = err
+ return err
+ }
+ cc.nwritten++
+
+ return nil
+}
+
+// Pending returns the number of unanswered requests
+// that have been sent on the connection.
+func (cc *ClientConn) Pending() int {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ return cc.nwritten - cc.nread
+}
+
+// Read reads the next response from the wire. A valid response might be
+// returned together with an ErrPersistEOF, which means that the remote
+// requested that this be the last request serviced. Read can be called
+// concurrently with Write, but not with another Read.
+func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
+ // Retrieve the pipeline ID of this request/response pair
+ cc.lk.Lock()
+ id, ok := cc.pipereq[req]
+ delete(cc.pipereq, req)
+ if !ok {
+ cc.lk.Unlock()
+ return nil, ErrPipeline
+ }
+ cc.lk.Unlock()
+
+ // Ensure pipeline order
+ cc.pipe.StartResponse(id)
+ defer cc.pipe.EndResponse(id)
+
+ cc.lk.Lock()
+ if cc.re != nil {
+ defer cc.lk.Unlock()
+ return nil, cc.re
+ }
+ if cc.r == nil { // connection closed by user in the meantime
+ defer cc.lk.Unlock()
+ return nil, errClosed
+ }
+ r := cc.r
+ lastbody := cc.lastbody
+ cc.lastbody = nil
+ cc.lk.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invokation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ cc.re = err
+ return nil, err
+ }
+ }
+
+ resp, err = http.ReadResponse(r, req)
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ if err != nil {
+ cc.re = err
+ return resp, err
+ }
+ cc.lastbody = resp.Body
+
+ cc.nread++
+
+ if resp.Close {
+ cc.re = ErrPersistEOF // don't send any more requests
+ return resp, cc.re
+ }
+ return resp, err
+}
+
+// Do is convenience method that writes a request and reads a response.
+func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) {
+ err = cc.Write(req)
+ if err != nil {
+ return
+ }
+ return cc.Read(req)
+}
diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go
new file mode 100644
index 000000000..1072e2e34
--- /dev/null
+++ b/src/pkg/net/http/httputil/reverseproxy.go
@@ -0,0 +1,167 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package httputil
+
+import (
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+type ReverseProxy struct {
+ // Director must be a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ Director func(*http.Request)
+
+ // The transport used to perform proxy requests.
+ // If nil, http.DefaultTransport is used.
+ Transport http.RoundTripper
+
+ // FlushInterval specifies the flush interval
+ // to flush to the client while copying the
+ // response body.
+ // If zero, no periodic flushing is done.
+ FlushInterval time.Duration
+}
+
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+ director := func(req *http.Request) {
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
+ req.URL.RawQuery = target.RawQuery
+ }
+ return &ReverseProxy{Director: director}
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ transport := p.Transport
+ if transport == nil {
+ transport = http.DefaultTransport
+ }
+
+ outreq := new(http.Request)
+ *outreq = *req // includes shallow copies of maps, but okay
+
+ p.Director(outreq)
+ outreq.Proto = "HTTP/1.1"
+ outreq.ProtoMajor = 1
+ outreq.ProtoMinor = 1
+ outreq.Close = false
+
+ // Remove the connection header to the backend. We want a
+ // persistent connection, regardless of what the client sent
+ // to us. This is modifying the same underlying map from req
+ // (shallow copied above) so we only copy it if necessary.
+ if outreq.Header.Get("Connection") != "" {
+ outreq.Header = make(http.Header)
+ copyHeader(outreq.Header, req.Header)
+ outreq.Header.Del("Connection")
+ }
+
+ if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ outreq.Header.Set("X-Forwarded-For", clientIp)
+ }
+
+ res, err := transport.RoundTrip(outreq)
+ if err != nil {
+ log.Printf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+ copyHeader(rw.Header(), res.Header)
+
+ rw.WriteHeader(res.StatusCode)
+
+ if res.Body != nil {
+ var dst io.Writer = rw
+ if p.FlushInterval != 0 {
+ if wf, ok := rw.(writeFlusher); ok {
+ dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+ }
+ }
+ io.Copy(dst, res.Body)
+ }
+}
+
+type writeFlusher interface {
+ io.Writer
+ http.Flusher
+}
+
+type maxLatencyWriter struct {
+ dst writeFlusher
+ latency time.Duration
+
+ lk sync.Mutex // protects init of done, as well Write + Flush
+ done chan bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+ m.lk.Lock()
+ defer m.lk.Unlock()
+ if m.done == nil {
+ m.done = make(chan bool)
+ go m.flushLoop()
+ }
+ n, err = m.dst.Write(p)
+ if err != nil {
+ m.done <- true
+ }
+ return
+}
+
+func (m *maxLatencyWriter) flushLoop() {
+ t := time.NewTicker(m.latency)
+ defer t.Stop()
+ for {
+ select {
+ case <-t.C:
+ m.lk.Lock()
+ m.dst.Flush()
+ m.lk.Unlock()
+ case <-m.done:
+ return
+ }
+ }
+ panic("unreached")
+}
diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go
new file mode 100644
index 000000000..655784b30
--- /dev/null
+++ b/src/pkg/net/http/httputil/reverseproxy_test.go
@@ -0,0 +1,71 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package httputil
+
+import (
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+)
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if len(r.TransferEncoding) > 0 {
+ t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
+ }
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got Connection header value %q", c)
+ }
+ if g, e := r.Host, "some-name"; g != e {
+ t.Errorf("backend got Host header %q, want %q", g, e)
+ }
+ w.Header().Set("X-Foo", "bar")
+ http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
+ t.Fatalf("got %d SetCookies, want %d", g, e)
+ }
+ if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
+ t.Errorf("unexpected cookie %q", cookie.Name)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go
new file mode 100644
index 000000000..2c2caa251
--- /dev/null
+++ b/src/pkg/net/http/jar.go
@@ -0,0 +1,30 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+)
+
+// A CookieJar manages storage and use of cookies in HTTP requests.
+//
+// Implementations of CookieJar must be safe for concurrent use by multiple
+// goroutines.
+type CookieJar interface {
+ // SetCookies handles the receipt of the cookies in a reply for the
+ // given URL. It may or may not choose to save the cookies, depending
+ // on the jar's policy and implementation.
+ SetCookies(u *url.URL, cookies []*Cookie)
+
+ // Cookies returns the cookies to send in a request for the given URL.
+ // It is up to the implementation to honor the standard cookie use
+ // restrictions such as in RFC 6265.
+ Cookies(u *url.URL) []*Cookie
+}
+
+type blackHoleJar struct{}
+
+func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {}
+func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil }
diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go
new file mode 100644
index 000000000..93b67e701
--- /dev/null
+++ b/src/pkg/net/http/lex.go
@@ -0,0 +1,144 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// This file deals with lexical matters of HTTP
+
+func isSeparator(c byte) bool {
+ switch c {
+ case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
+ return true
+ }
+ return false
+}
+
+func isSpace(c byte) bool {
+ switch c {
+ case ' ', '\t', '\r', '\n':
+ return true
+ }
+ return false
+}
+
+func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 }
+
+func isChar(c byte) bool { return 0 <= c && c <= 127 }
+
+func isAnyText(c byte) bool { return !isCtl(c) }
+
+func isQdText(c byte) bool { return isAnyText(c) && c != '"' }
+
+func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) }
+
+// Valid escaped sequences are not specified in RFC 2616, so for now, we assume
+// that they coincide with the common sense ones used by GO. Malformed
+// characters should probably not be treated as errors by a robust (forgiving)
+// parser, so we replace them with the '?' character.
+func httpUnquotePair(b byte) byte {
+ // skip the first byte, which should always be '\'
+ switch b {
+ case 'a':
+ return '\a'
+ case 'b':
+ return '\b'
+ case 'f':
+ return '\f'
+ case 'n':
+ return '\n'
+ case 'r':
+ return '\r'
+ case 't':
+ return '\t'
+ case 'v':
+ return '\v'
+ case '\\':
+ return '\\'
+ case '\'':
+ return '\''
+ case '"':
+ return '"'
+ }
+ return '?'
+}
+
+// raw must begin with a valid quoted string. Only the first quoted string is
+// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1
+// upon failure.
+func httpUnquote(raw []byte) (eaten int, result string) {
+ buf := make([]byte, len(raw))
+ if raw[0] != '"' {
+ return -1, ""
+ }
+ eaten = 1
+ j := 0 // # of bytes written in buf
+ for i := 1; i < len(raw); i++ {
+ switch b := raw[i]; b {
+ case '"':
+ eaten++
+ buf = buf[0:j]
+ return i + 1, string(buf)
+ case '\\':
+ if len(raw) < i+2 {
+ return -1, ""
+ }
+ buf[j] = httpUnquotePair(raw[i+1])
+ eaten += 2
+ j++
+ i++
+ default:
+ if isQdText(b) {
+ buf[j] = b
+ } else {
+ buf[j] = '?'
+ }
+ eaten++
+ j++
+ }
+ }
+ return -1, ""
+}
+
+// This is a best effort parse, so errors are not returned, instead not all of
+// the input string might be parsed. result is always non-nil.
+func httpSplitFieldValue(fv string) (eaten int, result []string) {
+ result = make([]string, 0, len(fv))
+ raw := []byte(fv)
+ i := 0
+ chunk := ""
+ for i < len(raw) {
+ b := raw[i]
+ switch {
+ case b == '"':
+ eaten, unq := httpUnquote(raw[i:len(raw)])
+ if eaten < 0 {
+ return i, result
+ } else {
+ i += eaten
+ chunk += unq
+ }
+ case isSeparator(b):
+ if chunk != "" {
+ result = result[0 : len(result)+1]
+ result[len(result)-1] = chunk
+ chunk = ""
+ }
+ i++
+ case isToken(b):
+ chunk += string(b)
+ i++
+ case b == '\n' || b == '\r':
+ i++
+ default:
+ chunk += "?"
+ i++
+ }
+ }
+ if chunk != "" {
+ result = result[0 : len(result)+1]
+ result[len(result)-1] = chunk
+ chunk = ""
+ }
+ return i, result
+}
diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go
new file mode 100644
index 000000000..5386f7534
--- /dev/null
+++ b/src/pkg/net/http/lex_test.go
@@ -0,0 +1,70 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "testing"
+)
+
+type lexTest struct {
+ Raw string
+ Parsed int // # of parsed characters
+ Result []string
+}
+
+var lexTests = []lexTest{
+ {
+ Raw: `"abc"def,:ghi`,
+ Parsed: 13,
+ Result: []string{"abcdef", "ghi"},
+ },
+ // My understanding of the RFC is that escape sequences outside of
+ // quotes are not interpreted?
+ {
+ Raw: `"\t"\t"\t"`,
+ Parsed: 10,
+ Result: []string{"\t", "t\t"},
+ },
+ {
+ Raw: `"\yab"\r\n`,
+ Parsed: 10,
+ Result: []string{"?ab", "r", "n"},
+ },
+ {
+ Raw: "ab\f",
+ Parsed: 3,
+ Result: []string{"ab?"},
+ },
+ {
+ Raw: "\"ab \" c,de f, gh, ij\n\t\r",
+ Parsed: 23,
+ Result: []string{"ab ", "c", "de", "f", "gh", "ij"},
+ },
+}
+
+func min(x, y int) int {
+ if x <= y {
+ return x
+ }
+ return y
+}
+
+func TestSplitFieldValue(t *testing.T) {
+ for k, l := range lexTests {
+ parsed, result := httpSplitFieldValue(l.Raw)
+ if parsed != l.Parsed {
+ t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed)
+ }
+ if len(result) != len(l.Result) {
+ t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result))
+ }
+ for i := 0; i < min(len(result), len(l.Result)); i++ {
+ if result[i] != l.Result[i] {
+ t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}",
+ k, i, result[i], l.Result[i])
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/http/pprof/Makefile b/src/pkg/net/http/pprof/Makefile
new file mode 100644
index 000000000..b78fce8e4
--- /dev/null
+++ b/src/pkg/net/http/pprof/Makefile
@@ -0,0 +1,11 @@
+# Copyright 2010 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/http/pprof
+GOFILES=\
+ pprof.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go
new file mode 100644
index 000000000..21eac4743
--- /dev/null
+++ b/src/pkg/net/http/pprof/pprof.go
@@ -0,0 +1,133 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pprof serves via its HTTP server runtime profiling data
+// in the format expected by the pprof visualization tool.
+// For more information about pprof, see
+// http://code.google.com/p/google-perftools/.
+//
+// The package is typically only imported for the side effect of
+// registering its HTTP handlers.
+// The handled paths all begin with /debug/pprof/.
+//
+// To use pprof, link this package into your program:
+// import _ "http/pprof"
+//
+// Then use the pprof tool to look at the heap profile:
+//
+// pprof http://localhost:6060/debug/pprof/heap
+//
+// Or to look at a 30-second CPU profile:
+//
+// pprof http://localhost:6060/debug/pprof/profile
+//
+package pprof
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func init() {
+ http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline))
+ http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile))
+ http.Handle("/debug/pprof/heap", http.HandlerFunc(Heap))
+ http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol))
+}
+
+// Cmdline responds with the running program's
+// command line, with arguments separated by NUL bytes.
+// The package initialization registers it as /debug/pprof/cmdline.
+func Cmdline(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprintf(w, strings.Join(os.Args, "\x00"))
+}
+
+// Heap responds with the pprof-formatted heap profile.
+// The package initialization registers it as /debug/pprof/heap.
+func Heap(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ pprof.WriteHeapProfile(w)
+}
+
+// Profile responds with the pprof-formatted cpu profile.
+// The package initialization registers it as /debug/pprof/profile.
+func Profile(w http.ResponseWriter, r *http.Request) {
+ sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
+ if sec == 0 {
+ sec = 30
+ }
+
+ // Set Content Type assuming StartCPUProfile will work,
+ // because if it does it starts writing.
+ w.Header().Set("Content-Type", "application/octet-stream")
+ if err := pprof.StartCPUProfile(w); err != nil {
+ // StartCPUProfile failed, so no writes yet.
+ // Can change header back to text content
+ // and send error code.
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err)
+ return
+ }
+ time.Sleep(time.Duration(sec) * time.Second)
+ pprof.StopCPUProfile()
+}
+
+// Symbol looks up the program counters listed in the request,
+// responding with a table mapping program counters to function names.
+// The package initialization registers it as /debug/pprof/symbol.
+func Symbol(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+
+ // We have to read the whole POST body before
+ // writing any output. Buffer the output here.
+ var buf bytes.Buffer
+
+ // We don't know how many symbols we have, but we
+ // do have symbol information. Pprof only cares whether
+ // this number is 0 (no symbols available) or > 0.
+ fmt.Fprintf(&buf, "num_symbols: 1\n")
+
+ var b *bufio.Reader
+ if r.Method == "POST" {
+ b = bufio.NewReader(r.Body)
+ } else {
+ b = bufio.NewReader(strings.NewReader(r.URL.RawQuery))
+ }
+
+ for {
+ word, err := b.ReadSlice('+')
+ if err == nil {
+ word = word[0 : len(word)-1] // trim +
+ }
+ pc, _ := strconv.ParseUint(string(word), 0, 64)
+ if pc != 0 {
+ f := runtime.FuncForPC(uintptr(pc))
+ if f != nil {
+ fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name())
+ }
+ }
+
+ // Wait until here to check for err; the last
+ // symbol will have an err because it doesn't end in +.
+ if err != nil {
+ if err != io.EOF {
+ fmt.Fprintf(&buf, "reading request: %v\n", err)
+ }
+ break
+ }
+ }
+
+ w.Write(buf.Bytes())
+}
diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go
new file mode 100644
index 000000000..9b320b3aa
--- /dev/null
+++ b/src/pkg/net/http/proxy_test.go
@@ -0,0 +1,48 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "os"
+ "testing"
+)
+
+// TODO(mattn):
+// test ProxyAuth
+
+var UseProxyTests = []struct {
+ host string
+ match bool
+}{
+ // Never proxy localhost:
+ {"localhost:80", false},
+ {"127.0.0.1", false},
+ {"127.0.0.2", false},
+ {"[::1]", false},
+ {"[::2]", true}, // not a loopback address
+
+ {"barbaz.net", false}, // match as .barbaz.net
+ {"foobar.com", false}, // have a port but match
+ {"foofoobar.com", true}, // not match as a part of foobar.com
+ {"baz.com", true}, // not match as a part of barbaz.com
+ {"localhost.net", true}, // not match as suffix of address
+ {"local.localhost", true}, // not match as prefix as address
+ {"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
+ {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com
+}
+
+func TestUseProxy(t *testing.T) {
+ oldenv := os.Getenv("NO_PROXY")
+ defer os.Setenv("NO_PROXY", oldenv)
+
+ no_proxy := "foobar.com, .barbaz.net"
+ os.Setenv("NO_PROXY", no_proxy)
+
+ for _, test := range UseProxyTests {
+ if useProxy(test.host+":80") != test.match {
+ t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
+ }
+ }
+}
diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go
new file mode 100644
index 000000000..5274a81fa
--- /dev/null
+++ b/src/pkg/net/http/range_test.go
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "testing"
+)
+
+var ParseRangeTests = []struct {
+ s string
+ length int64
+ r []httpRange
+}{
+ {"", 0, nil},
+ {"foo", 0, nil},
+ {"bytes=", 0, nil},
+ {"bytes=5-4", 10, nil},
+ {"bytes=0-2,5-4", 10, nil},
+ {"bytes=0-9", 10, []httpRange{{0, 10}}},
+ {"bytes=0-", 10, []httpRange{{0, 10}}},
+ {"bytes=5-", 10, []httpRange{{5, 5}}},
+ {"bytes=0-20", 10, []httpRange{{0, 10}}},
+ {"bytes=15-,0-5", 10, nil},
+ {"bytes=-5", 10, []httpRange{{5, 5}}},
+ {"bytes=-15", 10, []httpRange{{0, 10}}},
+ {"bytes=0-499", 10000, []httpRange{{0, 500}}},
+ {"bytes=500-999", 10000, []httpRange{{500, 500}}},
+ {"bytes=-500", 10000, []httpRange{{9500, 500}}},
+ {"bytes=9500-", 10000, []httpRange{{9500, 500}}},
+ {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
+ {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
+ {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+}
+
+func TestParseRange(t *testing.T) {
+ for _, test := range ParseRangeTests {
+ r := test.r
+ ranges, err := parseRange(test.s, test.length)
+ if err != nil && r != nil {
+ t.Errorf("parseRange(%q) returned error %q", test.s, err)
+ }
+ if len(ranges) != len(r) {
+ t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r))
+ continue
+ }
+ for i := range r {
+ if ranges[i].start != r[i].start {
+ t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start)
+ }
+ if ranges[i].length != r[i].length {
+ t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length)
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go
new file mode 100644
index 000000000..2e03c658a
--- /dev/null
+++ b/src/pkg/net/http/readrequest_test.go
@@ -0,0 +1,283 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/url"
+ "reflect"
+ "testing"
+)
+
+type reqTest struct {
+ Raw string
+ Req *Request
+ Body string
+ Trailer Header
+ Error string
+}
+
+var noError = ""
+var noBody = ""
+var noTrailer Header = nil
+
+var reqTests = []reqTest{
+ // Baseline test; All Request fields included for template use
+ {
+ "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Content-Length: 7\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n" +
+ "abcdef\n???",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "Content-Length": {"7"},
+ "User-Agent": {"Fake"},
+ },
+ Close: false,
+ ContentLength: 7,
+ Host: "www.techcrunch.com",
+ RequestURI: "http://www.techcrunch.com/",
+ },
+
+ "abcdef\n",
+
+ noTrailer,
+ noError,
+ },
+
+ // GET request with no body (the normal case)
+ {
+ "GET / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // Tests that we don't parse a path that looks like a
+ // scheme-relative URI as a scheme-relative URI.
+ {
+ "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "//user@host/is/actually/a/path/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "test",
+ RequestURI: "//user@host/is/actually/a/path/",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2)
+ {
+ "GET ../../../../etc/passwd HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBody,
+ noTrailer,
+ "parse ../../../../etc/passwd: invalid URI for request",
+ },
+
+ // Tests missing URL:
+ {
+ "GET HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBody,
+ noTrailer,
+ "parse : empty url",
+ },
+
+ // Tests chunked body with trailer:
+ {
+ "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "3\r\nfoo\r\n" +
+ "3\r\nbar\r\n" +
+ "0\r\n" +
+ "Trailer-Key: Trailer-Value\r\n" +
+ "\r\n",
+ &Request{
+ Method: "POST",
+ URL: &url.URL{
+ Path: "/",
+ },
+ TransferEncoding: []string{"chunked"},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ ContentLength: -1,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ "foobar",
+ Header{
+ "Trailer-Key": {"Trailer-Value"},
+ },
+ noError,
+ },
+
+ // CONNECT request with domain name:
+ {
+ "CONNECT www.google.com:443 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "www.google.com:443",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "www.google.com:443",
+ RequestURI: "www.google.com:443",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request with IP address:
+ {
+ "CONNECT 127.0.0.1:6060 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "127.0.0.1:6060",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "127.0.0.1:6060",
+ RequestURI: "127.0.0.1:6060",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request for RPC:
+ {
+ "CONNECT /_goRPC_ HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Path: "/_goRPC_",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "",
+ RequestURI: "/_goRPC_",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+}
+
+func TestReadRequest(t *testing.T) {
+ for i := range reqTests {
+ tt := &reqTests[i]
+ var braw bytes.Buffer
+ braw.WriteString(tt.Raw)
+ req, err := ReadRequest(bufio.NewReader(&braw))
+ if err != nil {
+ if err.Error() != tt.Error {
+ t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error)
+ }
+ continue
+ }
+ rbody := req.Body
+ req.Body = nil
+ diff(t, fmt.Sprintf("#%d Request", i), req, tt.Req)
+ var bout bytes.Buffer
+ if rbody != nil {
+ _, err := io.Copy(&bout, rbody)
+ if err != nil {
+ t.Fatalf("#%d. copying body: %v", i, err)
+ }
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
+ }
+ if !reflect.DeepEqual(tt.Trailer, req.Trailer) {
+ t.Errorf("#%d. Trailers differ.\n got: %v\nwant: %v", i, req.Trailer, tt.Trailer)
+ }
+ }
+}
diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go
new file mode 100644
index 000000000..5f8c00086
--- /dev/null
+++ b/src/pkg/net/http/request.go
@@ -0,0 +1,741 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Request reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime"
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+ "strings"
+)
+
+const (
+ maxValueLength = 4096
+ maxHeaderLines = 1024
+ chunkSize = 4 << 10 // 4 KB chunks
+ defaultMaxMemory = 32 << 20 // 32 MB
+)
+
+// ErrMissingFile is returned by FormFile when the provided file field name
+// is either not present in the request or not a file field.
+var ErrMissingFile = errors.New("http: no such file")
+
+// HTTP request parsing errors.
+type ProtocolError struct {
+ ErrorString string
+}
+
+func (err *ProtocolError) Error() string { return err.ErrorString }
+
+var (
+ ErrHeaderTooLong = &ProtocolError{"header too long"}
+ ErrShortBody = &ProtocolError{"entity body too short"}
+ ErrNotSupported = &ProtocolError{"feature not supported"}
+ ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
+ ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
+ ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
+ ErrMissingBoundary = &ProtocolError{"no multipart boundary param Content-Type"}
+)
+
+type badStringError struct {
+ what string
+ str string
+}
+
+func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
+
+// Headers that Request.Write handles itself and should be skipped.
+var reqWriteExcludeHeader = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "User-Agent": true,
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// A Request represents an HTTP request received by a server
+// or to be sent by a client.
+type Request struct {
+ Method string // GET, POST, PUT, etc.
+ URL *url.URL
+
+ // The protocol version for incoming requests.
+ // Outgoing requests always use HTTP/1.1.
+ Proto string // "HTTP/1.0"
+ ProtoMajor int // 1
+ ProtoMinor int // 0
+
+ // A header maps request lines to their values.
+ // If the header says
+ //
+ // accept-encoding: gzip, deflate
+ // Accept-Language: en-us
+ // Connection: keep-alive
+ //
+ // then
+ //
+ // Header = map[string][]string{
+ // "Accept-Encoding": {"gzip, deflate"},
+ // "Accept-Language": {"en-us"},
+ // "Connection": {"keep-alive"},
+ // }
+ //
+ // HTTP defines that header names are case-insensitive.
+ // The request parser implements this by canonicalizing the
+ // name, making the first character and any characters
+ // following a hyphen uppercase and the rest lowercase.
+ Header Header
+
+ // The message body.
+ Body io.ReadCloser
+
+ // ContentLength records the length of the associated content.
+ // The value -1 indicates that the length is unknown.
+ // Values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ // For outgoing requests, a value of 0 means unknown if Body is not nil.
+ ContentLength int64
+
+ // TransferEncoding lists the transfer encodings from outermost to
+ // innermost. An empty list denotes the "identity" encoding.
+ // TransferEncoding can usually be ignored; chunked encoding is
+ // automatically added and removed as necessary when sending and
+ // receiving requests.
+ TransferEncoding []string
+
+ // Close indicates whether to close the connection after
+ // replying to this request.
+ Close bool
+
+ // The host on which the URL is sought.
+ // Per RFC 2616, this is either the value of the Host: header
+ // or the host name given in the URL itself.
+ Host string
+
+ // Form contains the parsed form data, including both the URL
+ // field's query parameters and the POST or PUT form data.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores Form and uses Body instead.
+ Form url.Values
+
+ // MultipartForm is the parsed multipart form, including file uploads.
+ // This field is only available after ParseMultipartForm is called.
+ // The HTTP client ignores MultipartForm and uses Body instead.
+ MultipartForm *multipart.Form
+
+ // Trailer maps trailer keys to values. Like for Header, if the
+ // response has multiple trailer lines with the same key, they will be
+ // concatenated, delimited by commas.
+ // For server requests, Trailer is only populated after Body has been
+ // closed or fully consumed.
+ // Trailer support is only partially complete.
+ Trailer Header
+
+ // RemoteAddr allows HTTP servers and other software to record
+ // the network address that sent the request, usually for
+ // logging. This field is not filled in by ReadRequest and
+ // has no defined format. The HTTP server in this package
+ // sets RemoteAddr to an "IP:port" address before invoking a
+ // handler.
+ // This field is ignored by the HTTP client.
+ RemoteAddr string
+
+ // RequestURI is the unmodified Request-URI of the
+ // Request-Line (RFC 2616, Section 5.1) as sent by the client
+ // to a server. Usually the URL field should be used instead.
+ // It is an error to set this field in an HTTP client request.
+ RequestURI string
+
+ // TLS allows HTTP servers and other software to record
+ // information about the TLS connection on which the request
+ // was received. This field is not filled in by ReadRequest.
+ // The HTTP server in this package sets the field for
+ // TLS-enabled connections before invoking a handler;
+ // otherwise it leaves the field nil.
+ // This field is ignored by the HTTP client.
+ TLS *tls.ConnectionState
+}
+
+// ProtoAtLeast returns whether the HTTP protocol used
+// in the request is at least major.minor.
+func (r *Request) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// UserAgent returns the client's User-Agent, if sent in the request.
+func (r *Request) UserAgent() string {
+ return r.Header.Get("User-Agent")
+}
+
+// Cookies parses and returns the HTTP cookies sent with the request.
+func (r *Request) Cookies() []*Cookie {
+ return readCookies(r.Header, "")
+}
+
+var ErrNoCookie = errors.New("http: named cookied not present")
+
+// Cookie returns the named cookie provided in the request or
+// ErrNoCookie if not found.
+func (r *Request) Cookie(name string) (*Cookie, error) {
+ for _, c := range readCookies(r.Header, name) {
+ return c, nil
+ }
+ return nil, ErrNoCookie
+}
+
+// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
+// AddCookie does not attach more than one Cookie header field. That
+// means all cookies, if any, are written into the same line,
+// separated by semicolon.
+func (r *Request) AddCookie(c *Cookie) {
+ s := fmt.Sprintf("%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value))
+ if c := r.Header.Get("Cookie"); c != "" {
+ r.Header.Set("Cookie", c+"; "+s)
+ } else {
+ r.Header.Set("Cookie", s)
+ }
+}
+
+// Referer returns the referring URL, if sent in the request.
+//
+// Referer is misspelled as in the request itself, a mistake from the
+// earliest days of HTTP. This value can also be fetched from the
+// Header map as Header["Referer"]; the benefit of making it available
+// as a method is that the compiler can diagnose programs that use the
+// alternate (correct English) spelling req.Referrer() but cannot
+// diagnose programs that use Header["Referrer"].
+func (r *Request) Referer() string {
+ return r.Header.Get("Referer")
+}
+
+// multipartByReader is a sentinel value.
+// Its presence in Request.MultipartForm indicates that parsing of the request
+// body has been handed off to a MultipartReader instead of ParseMultipartFrom.
+var multipartByReader = &multipart.Form{
+ Value: make(map[string][]string),
+ File: make(map[string][]*multipart.FileHeader),
+}
+
+// MultipartReader returns a MIME multipart reader if this is a
+// multipart/form-data POST request, else returns nil and an error.
+// Use this function instead of ParseMultipartForm to
+// process the request body as a stream.
+func (r *Request) MultipartReader() (*multipart.Reader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, errors.New("http: MultipartReader called twice")
+ }
+ if r.MultipartForm != nil {
+ return nil, errors.New("http: multipart handled by ParseMultipartForm")
+ }
+ r.MultipartForm = multipartByReader
+ return r.multipartReader()
+}
+
+func (r *Request) multipartReader() (*multipart.Reader, error) {
+ v := r.Header.Get("Content-Type")
+ if v == "" {
+ return nil, ErrNotMultipart
+ }
+ d, params, err := mime.ParseMediaType(v)
+ if err != nil || d != "multipart/form-data" {
+ return nil, ErrNotMultipart
+ }
+ boundary, ok := params["boundary"]
+ if !ok {
+ return nil, ErrMissingBoundary
+ }
+ return multipart.NewReader(r.Body, boundary), nil
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+const defaultUserAgent = "Go http package"
+
+// Write writes an HTTP/1.1 request -- header and body -- in wire format.
+// This method consults the following fields of req:
+// Host
+// URL
+// Method (defaults to "GET")
+// Header
+// ContentLength
+// TransferEncoding
+// Body
+//
+// If Body is present, Content-Length is <= 0 and TransferEncoding
+// hasn't been set to "identity", Write adds "Transfer-Encoding:
+// chunked" to the header. Body is closed after it is sent.
+func (req *Request) Write(w io.Writer) error {
+ return req.write(w, false, nil)
+}
+
+// WriteProxy is like Write but writes the request in the form
+// expected by an HTTP proxy. In particular, WriteProxy writes the
+// initial Request-URI line of the request with an absolute URI, per
+// section 5.1.2 of RFC 2616, including the scheme and host. In
+// either case, WriteProxy also writes a Host header, using either
+// req.Host or req.URL.Host.
+func (req *Request) WriteProxy(w io.Writer) error {
+ return req.write(w, true, nil)
+}
+
+// extraHeaders may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error {
+ host := req.Host
+ if host == "" {
+ if req.URL == nil {
+ return errors.New("http: Request.Write on Request with no Host or URL set")
+ }
+ host = req.URL.Host
+ }
+
+ ruri := req.URL.RequestURI()
+ if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" {
+ ruri = req.URL.Scheme + "://" + host + ruri
+ } else if req.Method == "CONNECT" && req.URL.Path == "" {
+ // CONNECT requests normally give just the host and port, not a full URL.
+ ruri = host
+ }
+ // TODO(bradfitz): escape at least newlines in ruri?
+
+ bw := bufio.NewWriter(w)
+ fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+
+ // Header lines
+ fmt.Fprintf(bw, "Host: %s\r\n", host)
+
+ // Use the defaultUserAgent unless the Header contains one, which
+ // may be blank to not send the header.
+ userAgent := defaultUserAgent
+ if req.Header != nil {
+ if ua := req.Header["User-Agent"]; len(ua) > 0 {
+ userAgent = ua[0]
+ }
+ }
+ if userAgent != "" {
+ fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent)
+ }
+
+ // Process Body,ContentLength,Close,Trailer
+ tw, err := newTransferWriter(req)
+ if err != nil {
+ return err
+ }
+ err = tw.WriteHeader(bw)
+ if err != nil {
+ return err
+ }
+
+ // TODO: split long values? (If so, should share code with Conn.Write)
+ err = req.Header.WriteSubset(bw, reqWriteExcludeHeader)
+ if err != nil {
+ return err
+ }
+
+ if extraHeaders != nil {
+ err = extraHeaders.Write(bw)
+ if err != nil {
+ return err
+ }
+ }
+
+ io.WriteString(bw, "\r\n")
+
+ // Write body and trailer
+ err = tw.WriteBody(bw)
+ if err != nil {
+ return err
+ }
+
+ return bw.Flush()
+}
+
+// Convert decimal at s[i:len(s)] to integer,
+// returning value, string position where the digits stopped,
+// and whether there was a valid number (digits, not too big).
+func atoi(s string, i int) (n, i1 int, ok bool) {
+ const Big = 1000000
+ if i >= len(s) || s[i] < '0' || s[i] > '9' {
+ return 0, 0, false
+ }
+ n = 0
+ for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
+ n = n*10 + int(s[i]-'0')
+ if n > Big {
+ return 0, 0, false
+ }
+ }
+ return n, i, true
+}
+
+// ParseHTTPVersion parses a HTTP version string.
+// "HTTP/1.0" returns (1, 0, true).
+func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
+ if len(vers) < 5 || vers[0:5] != "HTTP/" {
+ return 0, 0, false
+ }
+ major, i, ok := atoi(vers, 5)
+ if !ok || i >= len(vers) || vers[i] != '.' {
+ return 0, 0, false
+ }
+ minor, i, ok = atoi(vers, i+1)
+ if !ok || i != len(vers) {
+ return 0, 0, false
+ }
+ return major, minor, true
+}
+
+// NewRequest returns a new Request given a method, URL, and optional body.
+func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return nil, err
+ }
+ rc, ok := body.(io.ReadCloser)
+ if !ok && body != nil {
+ rc = ioutil.NopCloser(body)
+ }
+ req := &Request{
+ Method: method,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ Body: rc,
+ Host: u.Host,
+ }
+ if body != nil {
+ switch v := body.(type) {
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
+ case *bytes.Buffer:
+ req.ContentLength = int64(v.Len())
+ }
+ }
+
+ return req, nil
+}
+
+// SetBasicAuth sets the request's Authorization header to use HTTP
+// Basic Authentication with the provided username and password.
+//
+// With HTTP Basic Authentication the provided username and password
+// are not encrypted.
+func (r *Request) SetBasicAuth(username, password string) {
+ s := username + ":" + password
+ r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s)))
+}
+
+// ReadRequest reads and parses a request from b.
+func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+
+ tp := textproto.NewReader(b)
+ req = new(Request)
+
+ // First line: GET /index.html HTTP/1.0
+ var s string
+ if s, err = tp.ReadLine(); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+
+ var f []string
+ if f = strings.SplitN(s, " ", 3); len(f) < 3 {
+ return nil, &badStringError{"malformed HTTP request", s}
+ }
+ req.Method, req.RequestURI, req.Proto = f[0], f[1], f[2]
+ rawurl := req.RequestURI
+ var ok bool
+ if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok {
+ return nil, &badStringError{"malformed HTTP version", req.Proto}
+ }
+
+ // CONNECT requests are used two different ways, and neither uses a full URL:
+ // The standard use is to tunnel HTTPS through an HTTP proxy.
+ // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is
+ // just the authority section of a URL. This information should go in req.URL.Host.
+ //
+ // The net/rpc package also uses CONNECT, but there the parameter is a path
+ // that starts with a slash. It can be parsed with the regular URL parser,
+ // and the path will end up in req.URL.Path, where it needs to be in order for
+ // RPC to work.
+ justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
+ if justAuthority {
+ rawurl = "http://" + rawurl
+ }
+
+ if req.URL, err = url.ParseRequest(rawurl); err != nil {
+ return nil, err
+ }
+
+ if justAuthority {
+ // Strip the bogus "http://" back off.
+ req.URL.Scheme = ""
+ }
+
+ // Subsequent lines: Key: value.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ return nil, err
+ }
+ req.Header = Header(mimeHeader)
+
+ // RFC2616: Must treat
+ // GET /index.html HTTP/1.1
+ // Host: www.google.com
+ // and
+ // GET http://www.google.com/index.html HTTP/1.1
+ // Host: doesntmatter
+ // the same. In the second case, any Host line is ignored.
+ req.Host = req.URL.Host
+ if req.Host == "" {
+ req.Host = req.Header.Get("Host")
+ }
+ req.Header.Del("Host")
+
+ fixPragmaCacheControl(req.Header)
+
+ // TODO: Parse specific header values:
+ // Accept
+ // Accept-Encoding
+ // Accept-Language
+ // Authorization
+ // Cache-Control
+ // Connection
+ // Date
+ // Expect
+ // From
+ // If-Match
+ // If-Modified-Since
+ // If-None-Match
+ // If-Range
+ // If-Unmodified-Since
+ // Max-Forwards
+ // Proxy-Authorization
+ // Referer [sic]
+ // TE (transfer-codings)
+ // Trailer
+ // Transfer-Encoding
+ // Upgrade
+ // User-Agent
+ // Via
+ // Warning
+
+ err = readTransfer(req, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return req, nil
+}
+
+// MaxBytesReader is similar to io.LimitReader but is intended for
+// limiting the size of incoming request bodies. In contrast to
+// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
+// non-EOF error for a Read beyond the limit, and Closes the
+// underlying reader when its Close method is called.
+//
+// MaxBytesReader prevents clients from accidentally or maliciously
+// sending a large request and wasting server resources.
+func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
+ return &maxBytesReader{w: w, r: r, n: n}
+}
+
+type maxBytesReader struct {
+ w ResponseWriter
+ r io.ReadCloser // underlying reader
+ n int64 // max bytes remaining
+ stopped bool
+}
+
+func (l *maxBytesReader) Read(p []byte) (n int, err error) {
+ if l.n <= 0 {
+ if !l.stopped {
+ l.stopped = true
+ if res, ok := l.w.(*response); ok {
+ res.requestTooLarge()
+ }
+ }
+ return 0, errors.New("http: request body too large")
+ }
+ if int64(len(p)) > l.n {
+ p = p[:l.n]
+ }
+ n, err = l.r.Read(p)
+ l.n -= int64(n)
+ return
+}
+
+func (l *maxBytesReader) Close() error {
+ return l.r.Close()
+}
+
+// ParseForm parses the raw query from the URL.
+//
+// For POST or PUT requests, it also parses the request body as a form.
+// If the request Body's size has not already been limited by MaxBytesReader,
+// the size is capped at 10MB.
+//
+// ParseMultipartForm calls ParseForm automatically.
+// It is idempotent.
+func (r *Request) ParseForm() (err error) {
+ if r.Form != nil {
+ return
+ }
+ if r.URL != nil {
+ r.Form, err = url.ParseQuery(r.URL.RawQuery)
+ }
+ if r.Method == "POST" || r.Method == "PUT" {
+ if r.Body == nil {
+ return errors.New("missing form body")
+ }
+ ct := r.Header.Get("Content-Type")
+ ct, _, err = mime.ParseMediaType(ct)
+ switch {
+ case ct == "application/x-www-form-urlencoded":
+ var reader io.Reader = r.Body
+ maxFormSize := int64(1<<63 - 1)
+ if _, ok := r.Body.(*maxBytesReader); !ok {
+ maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ reader = io.LimitReader(r.Body, maxFormSize+1)
+ }
+ b, e := ioutil.ReadAll(reader)
+ if e != nil {
+ if err == nil {
+ err = e
+ }
+ break
+ }
+ if int64(len(b)) > maxFormSize {
+ return errors.New("http: POST too large")
+ }
+ var newValues url.Values
+ newValues, e = url.ParseQuery(string(b))
+ if err == nil {
+ err = e
+ }
+ if r.Form == nil {
+ r.Form = make(url.Values)
+ }
+ // Copy values into r.Form. TODO: make this smoother.
+ for k, vs := range newValues {
+ for _, value := range vs {
+ r.Form.Add(k, value)
+ }
+ }
+ case ct == "multipart/form-data":
+ // handled by ParseMultipartForm (which is calling us, or should be)
+ // TODO(bradfitz): there are too many possible
+ // orders to call too many functions here.
+ // Clean this up and write more tests.
+ // request_test.go contains the start of this,
+ // in TestRequestMultipartCallOrder.
+ }
+ }
+ return err
+}
+
+// ParseMultipartForm parses a request body as multipart/form-data.
+// The whole request body is parsed and up to a total of maxMemory bytes of
+// its file parts are stored in memory, with the remainder stored on
+// disk in temporary files.
+// ParseMultipartForm calls ParseForm if necessary.
+// After one call to ParseMultipartForm, subsequent calls have no effect.
+func (r *Request) ParseMultipartForm(maxMemory int64) error {
+ if r.MultipartForm == multipartByReader {
+ return errors.New("http: multipart handled by MultipartReader")
+ }
+ if r.Form == nil {
+ err := r.ParseForm()
+ if err != nil {
+ return err
+ }
+ }
+ if r.MultipartForm != nil {
+ return nil
+ }
+
+ mr, err := r.multipartReader()
+ if err == ErrNotMultipart {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ f, err := mr.ReadForm(maxMemory)
+ if err != nil {
+ return err
+ }
+ for k, v := range f.Value {
+ r.Form[k] = append(r.Form[k], v...)
+ }
+ r.MultipartForm = f
+
+ return nil
+}
+
+// FormValue returns the first value for the named component of the query.
+// FormValue calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) FormValue(key string) string {
+ if r.Form == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.Form[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
+// FormFile returns the first file for the provided form key.
+// FormFile calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, nil, errors.New("http: multipart handled by MultipartReader")
+ }
+ if r.MultipartForm == nil {
+ err := r.ParseMultipartForm(defaultMaxMemory)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if r.MultipartForm != nil && r.MultipartForm.File != nil {
+ if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
+ f, err := fhs[0].Open()
+ return f, fhs[0], err
+ }
+ }
+ return nil, nil, ErrMissingFile
+}
+
+func (r *Request) expectsContinue() bool {
+ return strings.ToLower(r.Header.Get("Expect")) == "100-continue"
+}
+
+func (r *Request) wantsHttp10KeepAlive() bool {
+ if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
+ return false
+ }
+ return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive")
+}
diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go
new file mode 100644
index 000000000..7a3556d03
--- /dev/null
+++ b/src/pkg/net/http/request_test.go
@@ -0,0 +1,283 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+func TestQuery(t *testing.T) {
+ req := &Request{Method: "GET"}
+ req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar")
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+}
+
+func TestPostQuery(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x",
+ strings.NewReader("z=post&both=y"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+ if z := req.FormValue("z"); z != "post" {
+ t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) {
+ t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both)
+ }
+}
+
+type stringMap map[string][]string
+type parseContentTypeTest struct {
+ shouldError bool
+ contentType stringMap
+}
+
+var parseContentTypeTests = []parseContentTypeTest{
+ {false, stringMap{"Content-Type": {"text/plain"}}},
+ // Non-existent keys are not placed. The value nil is illegal.
+ {true, stringMap{}},
+ {true, stringMap{"Content-Type": {"text/plain; boundary="}}},
+ {false, stringMap{"Content-Type": {"application/unknown"}}},
+}
+
+func TestParseFormUnknownContentType(t *testing.T) {
+ for i, test := range parseContentTypeTests {
+ req := &Request{
+ Method: "POST",
+ Header: Header(test.contentType),
+ Body: ioutil.NopCloser(bytes.NewBufferString("body")),
+ }
+ err := req.ParseForm()
+ switch {
+ case err == nil && test.shouldError:
+ t.Errorf("test %d should have returned error", i)
+ case err != nil && !test.shouldError:
+ t.Errorf("test %d should not have returned error, got %v", i, err)
+ }
+ }
+}
+
+func TestMultipartReader(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ multipart, err := req.MultipartReader()
+ if multipart == nil {
+ t.Errorf("expected multipart; error: %v", err)
+ }
+
+ req.Header = Header{"Content-Type": {"text/plain"}}
+ multipart, err = req.MultipartReader()
+ if multipart != nil {
+ t.Errorf("unexpected multipart for text/plain")
+ }
+}
+
+func TestRedirect(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.URL.Path {
+ case "/":
+ w.Header().Set("Location", "/foo/")
+ w.WriteHeader(StatusSeeOther)
+ case "/foo/":
+ fmt.Fprintf(w, "foo")
+ default:
+ w.WriteHeader(StatusBadRequest)
+ }
+ }))
+ defer ts.Close()
+
+ var end = regexp.MustCompile("/foo/$")
+ r, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ url := r.Request.URL.String()
+ if r.StatusCode != 200 || !end.MatchString(url) {
+ t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url)
+ }
+}
+
+func TestSetBasicAuth(t *testing.T) {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.SetBasicAuth("Aladdin", "open sesame")
+ if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e {
+ t.Errorf("got header %q, want %q", g, e)
+ }
+}
+
+func TestMultipartRequest(t *testing.T) {
+ // Test that we can read the values and files of a
+ // multipart request with FormValue and FormFile,
+ // and that ParseMultipartForm can be called multiple times.
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm first call:", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ validateTestMultipartContents(t, req, false)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm second call:", err)
+ }
+ validateTestMultipartContents(t, req, false)
+}
+
+func TestMultipartRequestAuto(t *testing.T) {
+ // Test that FormValue and FormFile automatically invoke
+ // ParseMultipartForm and return the right values.
+ req := newTestMultipartRequest(t)
+ defer func() {
+ if req.MultipartForm != nil {
+ req.MultipartForm.RemoveAll()
+ }
+ }()
+ validateTestMultipartContents(t, req, true)
+}
+
+func TestEmptyMultipartRequest(t *testing.T) {
+ // Test that FormValue and FormFile automatically invoke
+ // ParseMultipartForm and return the right values.
+ req, err := NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Errorf("NewRequest err = %q", err)
+ }
+ testMissingFile(t, req)
+}
+
+func TestRequestMultipartCallOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ _, err := req.MultipartReader()
+ if err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ err = req.ParseMultipartForm(1024)
+ if err == nil {
+ t.Errorf("expected an error from ParseMultipartForm after call to MultipartReader")
+ }
+}
+
+func testMissingFile(t *testing.T, req *Request) {
+ f, fh, err := req.FormFile("missing")
+ if f != nil {
+ t.Errorf("FormFile file = %q, want nil", f)
+ }
+ if fh != nil {
+ t.Errorf("FormFile file header = %q, want nil", fh)
+ }
+ if err != ErrMissingFile {
+ t.Errorf("FormFile err = %q, want ErrMissingFile", err)
+ }
+}
+
+func newTestMultipartRequest(t *testing.T) *Request {
+ b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1))
+ req, err := NewRequest("POST", "/", b)
+ if err != nil {
+ t.Fatal("NewRequest:", err)
+ }
+ ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary)
+ req.Header.Set("Content-type", ctype)
+ return req
+}
+
+func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) {
+ if g, e := req.FormValue("texta"), textaValue; g != e {
+ t.Errorf("texta value = %q, want %q", g, e)
+ }
+ if g, e := req.FormValue("textb"), textbValue; g != e {
+ t.Errorf("textb value = %q, want %q", g, e)
+ }
+ if g := req.FormValue("missing"); g != "" {
+ t.Errorf("missing value = %q, want empty string", g)
+ }
+
+ assertMem := func(n string, fd multipart.File) {
+ if _, ok := fd.(*os.File); ok {
+ t.Error(n, " is *os.File, should not be")
+ }
+ }
+ fda := testMultipartFile(t, req, "filea", "filea.txt", fileaContents)
+ defer fda.Close()
+ assertMem("filea", fda)
+ fdb := testMultipartFile(t, req, "fileb", "fileb.txt", filebContents)
+ defer fdb.Close()
+ if allMem {
+ assertMem("fileb", fdb)
+ } else {
+ if _, ok := fdb.(*os.File); !ok {
+ t.Errorf("fileb has unexpected underlying type %T", fdb)
+ }
+ }
+
+ testMissingFile(t, req)
+}
+
+func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File {
+ f, fh, err := req.FormFile(key)
+ if err != nil {
+ t.Fatalf("FormFile(%q): %q", key, err)
+ }
+ if fh.Filename != expectFilename {
+ t.Errorf("filename = %q, want %q", fh.Filename, expectFilename)
+ }
+ var b bytes.Buffer
+ _, err = io.Copy(&b, f)
+ if err != nil {
+ t.Fatal("copying contents:", err)
+ }
+ if g := b.String(); g != expectContent {
+ t.Errorf("contents = %q, want %q", g, expectContent)
+ }
+ return f
+}
+
+const (
+ fileaContents = "This is a test file."
+ filebContents = "Another test file."
+ textaValue = "foo"
+ textbValue = "bar"
+ boundary = `MyBoundary`
+)
+
+const message = `
+--MyBoundary
+Content-Disposition: form-data; name="filea"; filename="filea.txt"
+Content-Type: text/plain
+
+` + fileaContents + `
+--MyBoundary
+Content-Disposition: form-data; name="fileb"; filename="fileb.txt"
+Content-Type: text/plain
+
+` + filebContents + `
+--MyBoundary
+Content-Disposition: form-data; name="texta"
+
+` + textaValue + `
+--MyBoundary
+Content-Disposition: form-data; name="textb"
+
+` + textbValue + `
+--MyBoundary--
+`
diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go
new file mode 100644
index 000000000..fc3186f0c
--- /dev/null
+++ b/src/pkg/net/http/requestwrite_test.go
@@ -0,0 +1,438 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+type reqWriteTest struct {
+ Req Request
+ Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ // Any of these three may be empty to skip that test.
+ WantWrite string // Request.Write
+ WantProxy string // Request.WriteProxy
+
+ WantError error // wanted error from Request.Write
+}
+
+var reqWriteTests = []reqWriteTest{
+ // HTTP/1.1 => chunked coding; no body; no trailer
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "User-Agent": {"Fake"},
+ },
+ Body: nil,
+ Close: false,
+ Host: "www.techcrunch.com",
+ Form: map[string][]string{},
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+
+ WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+ },
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+ // HTTP/1.1 POST => chunked coding; body; empty trailer
+ {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // HTTP/1.1 POST with Content-Length, no chunking
+ {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // HTTP/1.1 POST with Content-Length in headers
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("http://example.com/"),
+ Host: "example.com",
+ Header: Header{
+ "Content-Length": []string{"10"}, // ignored
+ },
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://example.com/ HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // default to HTTP/1.1
+ {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/search"),
+ Host: "www.google.com",
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 0 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) },
+
+ // RFC 2616 Section 14.13 says Content-Length should be specified
+ // unless body is prohibited by the request method.
+ // Also, nginx expects it for POST and PUT.
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 1 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) },
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+ },
+
+ // Request with a ContentLength of 10 but a 5 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 10, // but we're going to send only 5 bytes
+ },
+ Body: []byte("12345"),
+ WantError: errors.New("http: Request.ContentLength=10 with Body length 5"),
+ },
+
+ // Request with a ContentLength of 4 but an 8 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 4, // but we're going to try to send 8 bytes
+ },
+ Body: []byte("12345678"),
+ WantError: errors.New("http: Request.ContentLength=4 with Body length 8"),
+ },
+
+ // Request with a 5 ContentLength and nil body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 5, // but we'll omit the body
+ },
+ WantError: errors.New("http: Request.ContentLength=5 with nil Body"),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantWrite: "GET /foo HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go http package\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+}
+
+func TestRequestWrite(t *testing.T) {
+ for i := range reqWriteTests {
+ tt := &reqWriteTests[i]
+
+ setBody := func() {
+ if tt.Body == nil {
+ return
+ }
+ switch b := tt.Body.(type) {
+ case []byte:
+ tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
+ case func() io.ReadCloser:
+ tt.Req.Body = b()
+ }
+ }
+ setBody()
+ if tt.Req.Header == nil {
+ tt.Req.Header = make(Header)
+ }
+
+ var braw bytes.Buffer
+ err := tt.Req.Write(&braw)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e {
+ t.Errorf("writing #%d, err = %q, want %q", i, g, e)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+
+ if tt.WantWrite != "" {
+ sraw := braw.String()
+ if sraw != tt.WantWrite {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw)
+ continue
+ }
+ }
+
+ if tt.WantProxy != "" {
+ setBody()
+ var praw bytes.Buffer
+ err = tt.Req.WriteProxy(&praw)
+ if err != nil {
+ t.Errorf("WriteProxy #%d: %s", i, err)
+ continue
+ }
+ sraw := praw.String()
+ if sraw != tt.WantProxy {
+ t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw)
+ continue
+ }
+ }
+ }
+}
+
+type closeChecker struct {
+ io.Reader
+ closed bool
+}
+
+func (rc *closeChecker) Close() error {
+ rc.closed = true
+ return nil
+}
+
+// TestRequestWriteClosesBody tests that Request.Write does close its request.Body.
+// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer
+// inside a NopCloser, and that it serializes it correctly.
+func TestRequestWriteClosesBody(t *testing.T) {
+ rc := &closeChecker{Reader: strings.NewReader("my body")}
+ req, _ := NewRequest("POST", "http://foo.com/", rc)
+ if req.ContentLength != 0 {
+ t.Errorf("got req.ContentLength %d, want 0", req.ContentLength)
+ }
+ buf := new(bytes.Buffer)
+ req.Write(buf)
+ if !rc.closed {
+ t.Error("body not closed after write")
+ }
+ expected := "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ // TODO: currently we don't buffer before chunking, so we get a
+ // single "m" chunk before the other chunks, as this was the 1-byte
+ // read from our MultiReader where we stiched the Body back together
+ // after sniffing whether the Body was 0 bytes or not.
+ chunk("m") +
+ chunk("y body") +
+ chunk("")
+ if buf.String() != expected {
+ t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected)
+ }
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go
new file mode 100644
index 000000000..ae314b5ac
--- /dev/null
+++ b/src/pkg/net/http/response.go
@@ -0,0 +1,236 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Response reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net/textproto"
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+var respExcludeHeader = map[string]bool{
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// Response represents the response from an HTTP request.
+//
+type Response struct {
+ Status string // e.g. "200 OK"
+ StatusCode int // e.g. 200
+ Proto string // e.g. "HTTP/1.0"
+ ProtoMajor int // e.g. 1
+ ProtoMinor int // e.g. 0
+
+ // Header maps header keys to values. If the response had multiple
+ // headers with the same key, they will be concatenated, with comma
+ // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers
+ // be semantically equivalent to a comma-delimited sequence.) Values
+ // duplicated by other fields in this struct (e.g., ContentLength) are
+ // omitted from Header.
+ //
+ // Keys in the map are canonicalized (see CanonicalHeaderKey).
+ Header Header
+
+ // Body represents the response body.
+ //
+ // The http Client and Transport guarantee that Body is always
+ // non-nil, even on responses without a body or responses with
+ // a zero-lengthed body.
+ Body io.ReadCloser
+
+ // ContentLength records the length of the associated content. The
+ // value -1 indicates that the length is unknown. Unless RequestMethod
+ // is "HEAD", values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ ContentLength int64
+
+ // Contains transfer encodings from outer-most to inner-most. Value is
+ // nil, means that "identity" encoding is used.
+ TransferEncoding []string
+
+ // Close records whether the header directed that the connection be
+ // closed after reading Body. The value is advice for clients: neither
+ // ReadResponse nor Response.Write ever closes a connection.
+ Close bool
+
+ // Trailer maps trailer keys to values, in the same
+ // format as the header.
+ Trailer Header
+
+ // The Request that was sent to obtain this Response.
+ // Request's Body is nil (having already been consumed).
+ // This is only populated for Client requests.
+ Request *Request
+}
+
+// Cookies parses and returns the cookies set in the Set-Cookie headers.
+func (r *Response) Cookies() []*Cookie {
+ return readSetCookies(r.Header)
+}
+
+var ErrNoLocation = errors.New("http: no Location header in response")
+
+// Location returns the URL of the response's "Location" header,
+// if present. Relative redirects are resolved relative to
+// the Response's Request. ErrNoLocation is returned if no
+// Location header is present.
+func (r *Response) Location() (*url.URL, error) {
+ lv := r.Header.Get("Location")
+ if lv == "" {
+ return nil, ErrNoLocation
+ }
+ if r.Request != nil && r.Request.URL != nil {
+ return r.Request.URL.Parse(lv)
+ }
+ return url.Parse(lv)
+}
+
+// ReadResponse reads and returns an HTTP response from r. The
+// req parameter specifies the Request that corresponds to
+// this Response. Clients must call resp.Body.Close when finished
+// reading resp.Body. After that call, clients can inspect
+// resp.Trailer to find key/value pairs included in the response
+// trailer.
+func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) {
+
+ tp := textproto.NewReader(r)
+ resp = new(Response)
+
+ resp.Request = req
+ resp.Request.Method = strings.ToUpper(resp.Request.Method)
+
+ // Parse the first line of the response.
+ line, err := tp.ReadLine()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+ f := strings.SplitN(line, " ", 3)
+ if len(f) < 2 {
+ return nil, &badStringError{"malformed HTTP response", line}
+ }
+ reasonPhrase := ""
+ if len(f) > 2 {
+ reasonPhrase = f[2]
+ }
+ resp.Status = f[1] + " " + reasonPhrase
+ resp.StatusCode, err = strconv.Atoi(f[1])
+ if err != nil {
+ return nil, &badStringError{"malformed HTTP status code", f[1]}
+ }
+
+ resp.Proto = f[0]
+ var ok bool
+ if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
+ return nil, &badStringError{"malformed HTTP version", resp.Proto}
+ }
+
+ // Parse the response headers.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ return nil, err
+ }
+ resp.Header = Header(mimeHeader)
+
+ fixPragmaCacheControl(resp.Header)
+
+ err = readTransfer(resp, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
+
+// RFC2616: Should treat
+// Pragma: no-cache
+// like
+// Cache-Control: no-cache
+func fixPragmaCacheControl(header Header) {
+ if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
+ if _, presentcc := header["Cache-Control"]; !presentcc {
+ header["Cache-Control"] = []string{"no-cache"}
+ }
+ }
+}
+
+// ProtoAtLeast returns whether the HTTP protocol used
+// in the response is at least major.minor.
+func (r *Response) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// Writes the response (header, body and trailer) in wire format. This method
+// consults the following fields of resp:
+//
+// StatusCode
+// ProtoMajor
+// ProtoMinor
+// RequestMethod
+// TransferEncoding
+// Trailer
+// Body
+// ContentLength
+// Header, values for non-canonical keys will have unpredictable behavior
+//
+func (resp *Response) Write(w io.Writer) error {
+
+ // RequestMethod should be upper-case
+ if resp.Request != nil {
+ resp.Request.Method = strings.ToUpper(resp.Request.Method)
+ }
+
+ // Status line
+ text := resp.Status
+ if text == "" {
+ var ok bool
+ text, ok = statusText[resp.StatusCode]
+ if !ok {
+ text = "status code " + strconv.Itoa(resp.StatusCode)
+ }
+ }
+ io.WriteString(w, "HTTP/"+strconv.Itoa(resp.ProtoMajor)+".")
+ io.WriteString(w, strconv.Itoa(resp.ProtoMinor)+" ")
+ io.WriteString(w, strconv.Itoa(resp.StatusCode)+" "+text+"\r\n")
+
+ // Process Body,ContentLength,Close,Trailer
+ tw, err := newTransferWriter(resp)
+ if err != nil {
+ return err
+ }
+ err = tw.WriteHeader(w)
+ if err != nil {
+ return err
+ }
+
+ // Rest of header
+ err = resp.Header.WriteSubset(w, respExcludeHeader)
+ if err != nil {
+ return err
+ }
+
+ // End-of-header
+ io.WriteString(w, "\r\n")
+
+ // Write body and trailer
+ err = tw.WriteBody(w)
+ if err != nil {
+ return err
+ }
+
+ // Success
+ return nil
+}
diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go
new file mode 100644
index 000000000..e5d01698e
--- /dev/null
+++ b/src/pkg/net/http/response_test.go
@@ -0,0 +1,448 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/url"
+ "reflect"
+ "testing"
+)
+
+type respTest struct {
+ Raw string
+ Resp Response
+ Body string
+}
+
+func dummyReq(method string) *Request {
+ return &Request{Method: method}
+}
+
+var respTests = []respTest{
+ // Unchunked response without Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 response without Content-Length or
+ // Connection headers.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 204 response without Content-Length.
+ {
+ "HTTP/1.1 204 No Content\r\n" +
+ "\r\n" +
+ "Body should not be read!\n",
+
+ Response{
+ Status: "204 No Content",
+ StatusCode: 204,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Unchunked response with Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 10\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ "Content-Length": {"10"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response without Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n\r\n" +
+ "09\r\n" +
+ "continued\r\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\ncontinued",
+ },
+
+ // Chunked response with Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Content-Length: 10\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1, // TODO(rsc): Fix?
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response in response to a HEAD request (the "chunked" should
+ // be ignored, as HEAD responses never have bodies)
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ Close: true,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // explicit Content-Length of 0.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"0"},
+ },
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, but trailing space.
+ // (permitted by RFC 2616)
+ {
+ "HTTP/1.0 303 \r\n\r\n",
+ Response{
+ Status: "303 ",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, and no trailing space.
+ // (not permitted by RFC 2616, but we'll accept it anyway)
+ {
+ "HTTP/1.0 303\r\n\r\n",
+ Response{
+ Status: "303 ",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+}
+
+func TestReadResponse(t *testing.T) {
+ for i := range respTests {
+ tt := &respTests[i]
+ var braw bytes.Buffer
+ braw.WriteString(tt.Raw)
+ resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %s", i, err)
+ continue
+ }
+ rbody := resp.Body
+ resp.Body = nil
+ diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp)
+ var bout bytes.Buffer
+ if rbody != nil {
+ io.Copy(&bout, rbody)
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
+ }
+ }
+}
+
+var readResponseCloseInMiddleTests = []struct {
+ chunked, compressed bool
+}{
+ {false, false},
+ {true, false},
+ {true, true},
+}
+
+// TestReadResponseCloseInMiddle tests that closing a body after
+// reading only part of its contents advances the read to the end of
+// the request, right up until the next request.
+func TestReadResponseCloseInMiddle(t *testing.T) {
+ for _, test := range readResponseCloseInMiddleTests {
+ fatalf := func(format string, args ...interface{}) {
+ args = append([]interface{}{test.chunked, test.compressed}, args...)
+ t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
+ }
+ checkErr := func(err error, msg string) {
+ if err == nil {
+ return
+ }
+ fatalf(msg+": %v", err)
+ }
+ var buf bytes.Buffer
+ buf.WriteString("HTTP/1.1 200 OK\r\n")
+ if test.chunked {
+ buf.WriteString("Transfer-Encoding: chunked\r\n")
+ } else {
+ buf.WriteString("Content-Length: 1000000\r\n")
+ }
+ var wr io.Writer = &buf
+ if test.chunked {
+ wr = newChunkedWriter(wr)
+ }
+ if test.compressed {
+ buf.WriteString("Content-Encoding: gzip\r\n")
+ var err error
+ wr, err = gzip.NewWriter(wr)
+ checkErr(err, "gzip.NewWriter")
+ }
+ buf.WriteString("\r\n")
+
+ chunk := bytes.Repeat([]byte{'x'}, 1000)
+ for i := 0; i < 1000; i++ {
+ if test.compressed {
+ // Otherwise this compresses too well.
+ _, err := io.ReadFull(rand.Reader, chunk)
+ checkErr(err, "rand.Reader ReadFull")
+ }
+ wr.Write(chunk)
+ }
+ if test.compressed {
+ err := wr.(*gzip.Compressor).Close()
+ checkErr(err, "compressor close")
+ }
+ if test.chunked {
+ buf.WriteString("0\r\n\r\n")
+ }
+ buf.WriteString("Next Request Here")
+
+ bufr := bufio.NewReader(&buf)
+ resp, err := ReadResponse(bufr, dummyReq("GET"))
+ checkErr(err, "ReadResponse")
+ expectedLength := int64(-1)
+ if !test.chunked {
+ expectedLength = 1000000
+ }
+ if resp.ContentLength != expectedLength {
+ fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength)
+ }
+ if resp.Body == nil {
+ fatalf("nil body")
+ }
+ if test.compressed {
+ gzReader, err := gzip.NewReader(resp.Body)
+ checkErr(err, "gzip.NewReader")
+ resp.Body = &readFirstCloseBoth{gzReader, resp.Body}
+ }
+
+ rbuf := make([]byte, 2500)
+ n, err := io.ReadFull(resp.Body, rbuf)
+ checkErr(err, "2500 byte ReadFull")
+ if n != 2500 {
+ fatalf("ReadFull only read %d bytes", n)
+ }
+ if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
+ fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf))
+ }
+ resp.Body.Close()
+
+ rest, err := ioutil.ReadAll(bufr)
+ checkErr(err, "ReadAll on remainder")
+ if e, g := "Next Request Here", string(rest); e != g {
+ fatalf("remainder = %q, expected %q", g, e)
+ }
+ }
+}
+
+func diff(t *testing.T, prefix string, have, want interface{}) {
+ hv := reflect.ValueOf(have).Elem()
+ wv := reflect.ValueOf(want).Elem()
+ if hv.Type() != wv.Type() {
+ t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type())
+ }
+ for i := 0; i < hv.NumField(); i++ {
+ hf := hv.Field(i).Interface()
+ wf := wv.Field(i).Interface()
+ if !reflect.DeepEqual(hf, wf) {
+ t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf)
+ }
+ }
+}
+
+type responseLocationTest struct {
+ location string // Response's Location header or ""
+ requrl string // Response.Request.URL or ""
+ want string
+ wantErr error
+}
+
+var responseLocationTests = []responseLocationTest{
+ {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil},
+ {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil},
+ {"", "http://bar.com/baz", "", ErrNoLocation},
+}
+
+func TestLocationResponse(t *testing.T) {
+ for i, tt := range responseLocationTests {
+ res := new(Response)
+ res.Header = make(Header)
+ res.Header.Set("Location", tt.location)
+ if tt.requrl != "" {
+ res.Request = &Request{}
+ var err error
+ res.Request.URL, err = url.Parse(tt.requrl)
+ if err != nil {
+ t.Fatalf("bad test URL %q: %v", tt.requrl, err)
+ }
+ }
+
+ got, err := res.Location()
+ if tt.wantErr != nil {
+ if err == nil {
+ t.Errorf("%d. err=nil; want %q", i, tt.wantErr)
+ continue
+ }
+ if g, e := err.Error(), tt.wantErr.Error(); g != e {
+ t.Errorf("%d. err=%q; want %q", i, g, e)
+ continue
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%d. err=%q", i, err)
+ continue
+ }
+ if g, e := got.String(), tt.want; g != e {
+ t.Errorf("%d. Location=%q; want %q", i, g, e)
+ }
+ }
+}
diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go
new file mode 100644
index 000000000..f8e63acf4
--- /dev/null
+++ b/src/pkg/net/http/responsewrite_test.go
@@ -0,0 +1,109 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "io/ioutil"
+ "testing"
+)
+
+type respWriteTest struct {
+ Resp Response
+ Raw string
+}
+
+var respWriteTests = []respWriteTest{
+ // HTTP/1.0, identity coding; no trailer
+ {
+ Response{
+ StatusCode: 503,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ },
+
+ "HTTP/1.0 503 Service Unavailable\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "abcdef",
+ },
+ // Unchunked response without Content-Length.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: -1,
+ },
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1, chunked coding; empty trailer; close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+
+ // Header value with a newline character (Issue 914).
+ // Also tests removal of leading and trailing whitespace.
+ {
+ Response{
+ StatusCode: 204,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Foo": []string{" Bar\nBaz "},
+ },
+ Body: nil,
+ ContentLength: 0,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 204 No Content\r\n" +
+ "Connection: close\r\n" +
+ "Foo: Bar Baz\r\n" +
+ "\r\n",
+ },
+}
+
+func TestResponseWrite(t *testing.T) {
+ for i := range respWriteTests {
+ tt := &respWriteTests[i]
+ var braw bytes.Buffer
+ err := tt.Resp.Write(&braw)
+ if err != nil {
+ t.Errorf("error writing #%d: %s", i, err)
+ continue
+ }
+ sraw := braw.String()
+ if sraw != tt.Raw {
+ t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw)
+ continue
+ }
+ }
+}
diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go
new file mode 100644
index 000000000..147c216ec
--- /dev/null
+++ b/src/pkg/net/http/serve_test.go
@@ -0,0 +1,1182 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// End-to-end serving tests
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "os"
+ "reflect"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+)
+
+type dummyAddr string
+type oneConnListener struct {
+ conn net.Conn
+}
+
+func (l *oneConnListener) Accept() (c net.Conn, err error) {
+ c = l.conn
+ if c == nil {
+ err = io.EOF
+ return
+ }
+ err = nil
+ l.conn = nil
+ return
+}
+
+func (l *oneConnListener) Close() error {
+ return nil
+}
+
+func (l *oneConnListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func (a dummyAddr) Network() string {
+ return string(a)
+}
+
+func (a dummyAddr) String() string {
+ return string(a)
+}
+
+type testConn struct {
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+}
+
+func (c *testConn) Read(b []byte) (int, error) {
+ return c.readBuf.Read(b)
+}
+
+func (c *testConn) Write(b []byte) (int, error) {
+ return c.writeBuf.Write(b)
+}
+
+func (c *testConn) Close() error {
+ return nil
+}
+
+func (c *testConn) LocalAddr() net.Addr {
+ return dummyAddr("local-addr")
+}
+
+func (c *testConn) RemoteAddr() net.Addr {
+ return dummyAddr("remote-addr")
+}
+
+func (c *testConn) SetDeadline(t time.Time) error {
+ return nil
+}
+
+func (c *testConn) SetReadDeadline(t time.Time) error {
+ return nil
+}
+
+func (c *testConn) SetWriteDeadline(t time.Time) error {
+ return nil
+}
+
+func TestConsumingBodyOnNextConn(t *testing.T) {
+ conn := new(testConn)
+ for i := 0; i < 2; i++ {
+ conn.readBuf.Write([]byte(
+ "POST / HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 11\r\n" +
+ "\r\n" +
+ "foo=1&bar=1"))
+ }
+
+ reqNum := 0
+ ch := make(chan *Request)
+ servech := make(chan error)
+ listener := &oneConnListener{conn}
+ handler := func(res ResponseWriter, req *Request) {
+ reqNum++
+ ch <- req
+ }
+
+ go func() {
+ servech <- Serve(listener, HandlerFunc(handler))
+ }()
+
+ var req *Request
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #1's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #2's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ if serveerr := <-servech; serveerr != io.EOF {
+ t.Errorf("Serve returned %q; expected EOF", serveerr)
+ }
+}
+
+type stringHandler string
+
+func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Result", string(s))
+}
+
+var handlers = []struct {
+ pattern string
+ msg string
+}{
+ {"/", "Default"},
+ {"/someDir/", "someDir"},
+ {"someHost.com/someDir/", "someHost.com/someDir"},
+}
+
+var vtests = []struct {
+ url string
+ expected string
+}{
+ {"http://localhost/someDir/apage", "someDir"},
+ {"http://localhost/otherDir/apage", "Default"},
+ {"http://someHost.com/someDir/apage", "someHost.com/someDir"},
+ {"http://otherHost.com/someDir/apage", "someDir"},
+ {"http://otherHost.com/aDir/apage", "Default"},
+}
+
+func TestHostHandlers(t *testing.T) {
+ for _, h := range handlers {
+ Handle(h.pattern, stringHandler(h.msg))
+ }
+ ts := httptest.NewServer(nil)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ cc := httputil.NewClientConn(conn, nil)
+ for _, vt := range vtests {
+ var r *Response
+ var req Request
+ if req.URL, err = url.Parse(vt.url); err != nil {
+ t.Errorf("cannot parse url: %v", err)
+ continue
+ }
+ if err := cc.Write(&req); err != nil {
+ t.Errorf("writing request: %v", err)
+ continue
+ }
+ r, err := cc.Read(&req)
+ if err != nil {
+ t.Errorf("reading response: %v", err)
+ continue
+ }
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ }
+}
+
+// Tests for http://code.google.com/p/go/issues/detail?id=900
+func TestMuxRedirectLeadingSlashes(t *testing.T) {
+ paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
+ for _, path := range paths {
+ req, err := ReadRequest(bufio.NewReader(bytes.NewBufferString("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
+ if err != nil {
+ t.Errorf("%s", err)
+ }
+ mux := NewServeMux()
+ resp := httptest.NewRecorder()
+
+ mux.ServeHTTP(resp, req)
+
+ if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
+ t.Errorf("Expected Location header set to %q; got %q", expected, loc)
+ return
+ }
+
+ if code, expected := resp.Code, StatusMovedPermanently; code != expected {
+ t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
+ return
+ }
+ }
+}
+
+func TestServerTimeouts(t *testing.T) {
+ // TODO(bradfitz): convert this to use httptest.Server
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("listen error: %v", err)
+ }
+ addr, _ := l.Addr().(*net.TCPAddr)
+
+ reqNum := 0
+ handler := HandlerFunc(func(res ResponseWriter, req *Request) {
+ reqNum++
+ fmt.Fprintf(res, "req=%d", reqNum)
+ })
+
+ const second = 1000000000 /* nanos */
+ server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second}
+ go server.Serve(l)
+
+ url := fmt.Sprintf("http://%s/", addr)
+
+ // Hit the HTTP server successfully.
+ tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+ c := &Client{Transport: tr}
+ r, err := c.Get(url)
+ if err != nil {
+ t.Fatalf("http Get #1: %v", err)
+ }
+ got, _ := ioutil.ReadAll(r.Body)
+ expected := "req=1"
+ if string(got) != expected {
+ t.Errorf("Unexpected response for request #1; got %q; expected %q",
+ string(got), expected)
+ }
+
+ // Slow client that should timeout.
+ t1 := time.Now()
+ conn, err := net.Dial("tcp", addr.String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ buf := make([]byte, 1)
+ n, err := conn.Read(buf)
+ latency := time.Now().Sub(t1)
+ if n != 0 || err != io.EOF {
+ t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
+ }
+ if latency < 200*time.Millisecond /* fudge from 0.25 above */ {
+ t.Errorf("got EOF after %s, want >= %s", latency, 200*time.Millisecond)
+ }
+
+ // Hit the HTTP server successfully again, verifying that the
+ // previous slow connection didn't run our handler. (that we
+ // get "req=2", not "req=3")
+ r, err = Get(url)
+ if err != nil {
+ t.Fatalf("http Get #2: %v", err)
+ }
+ got, _ = ioutil.ReadAll(r.Body)
+ expected = "req=2"
+ if string(got) != expected {
+ t.Errorf("Get #2 got %q, want %q", string(got), expected)
+ }
+
+ l.Close()
+}
+
+// TestIdentityResponse verifies that a handler can unset
+func TestIdentityResponse(t *testing.T) {
+ handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Header().Set("Content-Length", "3")
+ rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
+ switch {
+ case req.FormValue("overwrite") == "1":
+ _, err := rw.Write([]byte("foo TOO LONG"))
+ if err != ErrContentLength {
+ t.Errorf("expected ErrContentLength; got %v", err)
+ }
+ case req.FormValue("underwrite") == "1":
+ rw.Header().Set("Content-Length", "500")
+ rw.Write([]byte("too short"))
+ default:
+ rw.Write([]byte("foo"))
+ }
+ })
+
+ ts := httptest.NewServer(handler)
+ defer ts.Close()
+
+ // Note: this relies on the assumption (which is true) that
+ // Get sends HTTP/1.1 or greater requests. Otherwise the
+ // server wouldn't have the choice to send back chunked
+ // responses.
+ for _, te := range []string{"", "identity"} {
+ url := ts.URL + "/?te=" + te
+ res, err := Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ if cl, expected := res.ContentLength, int64(3); cl != expected {
+ t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
+ }
+ if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
+ t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
+ }
+ if tl, expected := len(res.TransferEncoding), 0; tl != expected {
+ t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
+ url, expected, tl, res.TransferEncoding)
+ }
+ res.Body.Close()
+ }
+
+ // Verify that ErrContentLength is returned
+ url := ts.URL + "/?overwrite=1"
+ _, err := Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ // Verify that the connection is closed when the declared Content-Length
+ // is larger than what the handler wrote.
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+
+ // The ReadAll will hang for a failing test, so use a Timer to
+ // fail explicitly.
+ goTimeout(t, 2*time.Second, func() {
+ got, _ := ioutil.ReadAll(conn)
+ expectedSuffix := "\r\n\r\ntoo short"
+ if !strings.HasSuffix(string(got), expectedSuffix) {
+ t.Errorf("Expected output to end with %q; got response body %q",
+ expectedSuffix, string(got))
+ }
+ })
+}
+
+func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
+ s := httptest.NewServer(h)
+ defer s.Close()
+
+ conn, err := net.Dial("tcp", s.Listener.Addr().String())
+ if err != nil {
+ t.Fatal("dial error:", err)
+ }
+ defer conn.Close()
+
+ _, err = fmt.Fprint(conn, req)
+ if err != nil {
+ t.Fatal("print error:", err)
+ }
+
+ r := bufio.NewReader(conn)
+ _, err = ReadResponse(r, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse error:", err)
+ }
+
+ success := make(chan bool)
+ go func() {
+ select {
+ case <-time.After(5 * time.Second):
+ t.Fatal("body not closed after 5s")
+ case <-success:
+ }
+ }()
+
+ _, err = ioutil.ReadAll(r)
+ if err != nil {
+ t.Fatal("read error:", err)
+ }
+
+ success <- true
+}
+
+// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
+func TestServeHTTP10Close(t *testing.T) {
+ testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ }))
+}
+
+// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
+// even for HTTP/1.1 requests.
+func TestHandlersCanSetConnectionClose11(t *testing.T) {
+ testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestHandlersCanSetConnectionClose10(t *testing.T) {
+ testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestSetsRemoteAddr(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%s", r.RemoteAddr)
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ ip := string(body)
+ if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
+ t.Fatalf("Expected local addr; got %q", ip)
+ }
+}
+
+func TestChunkedResponseHeaders(t *testing.T) {
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ fmt.Fprintf(w, "I am a chunked response.")
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ if g, e := res.ContentLength, int64(-1); g != e {
+ t.Errorf("expected ContentLength of %d; got %d", e, g)
+ }
+ if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) {
+ t.Errorf("expected TransferEncoding of %v; got %v", e, g)
+ }
+ if _, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length")
+ }
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNotModified)
+ _, err := w.Write([]byte("illegal body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+// TestHeadResponses verifies that responses to HEAD requests don't
+// declare that they're chunking in their response headers and aren't
+// allowed to produce output.
+func TestHeadResponses(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("Ignored body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+
+ // Also exercise the ReaderFrom path
+ _, err = io.Copy(w, strings.NewReader("Ignored body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer ts.Close()
+ res, err := Head(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestTLSHandshakeTimeout(t *testing.T) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.StartTLS()
+ defer ts.Close()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ goTimeout(t, 10*time.Second, func() {
+ var buf [1]byte
+ n, err := conn.Read(buf[:])
+ if err == nil || n != 0 {
+ t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
+ }
+ })
+}
+
+func TestTLSServer(t *testing.T) {
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS != nil {
+ w.Header().Set("X-TLS-Set", "true")
+ if r.TLS.HandshakeComplete {
+ w.Header().Set("X-TLS-HandshakeComplete", "true")
+ }
+ }
+ }))
+ defer ts.Close()
+
+ // Connect an idle TCP connection to this server before we run
+ // our real tests. This idle connection used to block forever
+ // in the TLS handshake, preventing future connections from
+ // being accepted. It may prevent future accidental blocking
+ // in newConn.
+ idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer idleConn.Close()
+ goTimeout(t, 10*time.Second, func() {
+ if !strings.HasPrefix(ts.URL, "https://") {
+ t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
+ return
+ }
+ noVerifyTransport := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ }
+ client := &Client{Transport: noVerifyTransport}
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if res == nil {
+ t.Errorf("got nil Response")
+ return
+ }
+ defer res.Body.Close()
+ if res.Header.Get("X-TLS-Set") != "true" {
+ t.Errorf("expected X-TLS-Set response header")
+ return
+ }
+ if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
+ t.Errorf("expected X-TLS-HandshakeComplete header")
+ }
+ })
+}
+
+type serverExpectTest struct {
+ contentLength int // of request body
+ expectation string // e.g. "100-continue"
+ readBody bool // whether handler should read the body (if false, sends StatusUnauthorized)
+ expectedResponse string // expected substring in first line of http response
+}
+
+var serverExpectTests = []serverExpectTest{
+ // Normal 100-continues, case-insensitive.
+ {100, "100-continue", true, "100 Continue"},
+ {100, "100-cOntInUE", true, "100 Continue"},
+
+ // No 100-continue.
+ {100, "", true, "200 OK"},
+
+ // 100-continue but requesting client to deny us,
+ // so it never reads the body.
+ {100, "100-continue", false, "401 Unauthorized"},
+ // Likewise without 100-continue:
+ {100, "", false, "401 Unauthorized"},
+
+ // Non-standard expectations are failures
+ {0, "a-pony", false, "417 Expectation Failed"},
+
+ // Expect-100 requested but no body
+ {0, "100-continue", true, "400 Bad Request"},
+}
+
+// Tests that the server responds to the "Expect" request header
+// correctly.
+func TestServerExpect(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Note using r.FormValue("readbody") because for POST
+ // requests that would read from r.Body, which we only
+ // conditionally want to do.
+ if strings.Contains(r.URL.RawQuery, "readbody=true") {
+ ioutil.ReadAll(r.Body)
+ w.Write([]byte("Hi"))
+ } else {
+ w.WriteHeader(StatusUnauthorized)
+ }
+ }))
+ defer ts.Close()
+
+ runTest := func(test serverExpectTest) {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ sendf := func(format string, args ...interface{}) {
+ _, err := fmt.Fprintf(conn, format, args...)
+ if err != nil {
+ t.Fatalf("On test %#v, error writing %q: %v", test, format, err)
+ }
+ }
+ go func() {
+ sendf("POST /?readbody=%v HTTP/1.1\r\n"+
+ "Connection: close\r\n"+
+ "Content-Length: %d\r\n"+
+ "Expect: %s\r\nHost: foo\r\n\r\n",
+ test.readBody, test.contentLength, test.expectation)
+ if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" {
+ body := strings.Repeat("A", test.contentLength)
+ sendf(body)
+ }
+ }()
+ bufr := bufio.NewReader(conn)
+ line, err := bufr.ReadString('\n')
+ if err != nil {
+ t.Fatalf("ReadString: %v", err)
+ }
+ if !strings.Contains(line, test.expectedResponse) {
+ t.Errorf("for test %#v got first line=%q", test, line)
+ }
+ }
+
+ for _, test := range serverExpectTests {
+ runTest(test)
+ }
+}
+
+// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should consume client request bodies that a handler didn't read.
+func TestServerUnreadRequestBodyLittle(t *testing.T) {
+ conn := new(testConn)
+ body := strings.Repeat("x", 100<<10)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+
+ done := make(chan bool)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ if g, e := conn.readBuf.Len(), 0; g != e {
+ t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
+ }
+ if c := rw.Header().Get("Connection"); c != "" {
+ t.Errorf(`Connection header = %q; want ""`, c)
+ }
+ }))
+ <-done
+}
+
+// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should ignore client request bodies that a handler didn't read
+// and close the connection.
+func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ conn := new(testConn)
+ body := strings.Repeat("x", 1<<20)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+
+ done := make(chan bool)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ if c := rw.Header().Get("Connection"); c != "close" {
+ t.Errorf(`Connection header = %q; want "close"`, c)
+ }
+ }))
+ <-done
+}
+
+func TestTimeoutHandler(t *testing.T) {
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ timeout := make(chan time.Time, 1) // write to this to force timeouts
+ ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout))
+ defer ts.Close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ timeout <- time.Time{}
+ res, err = Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = ioutil.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
+
+// Verifies we don't path.Clean() on the wrong parts in redirects.
+func TestRedirectMunging(t *testing.T) {
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+
+ resp := httptest.NewRecorder()
+ Redirect(resp, req, "/foo?next=http://bar.com/", 302)
+ if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e {
+ t.Errorf("Location header was %q; want %q", g, e)
+ }
+
+ resp = httptest.NewRecorder()
+ Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302)
+ if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e {
+ t.Errorf("Location header was %q; want %q", g, e)
+ }
+}
+
+// TestZeroLengthPostAndResponse exercises an optimization done by the Transport:
+// when there is no body (either because the method doesn't permit a body, or an
+// explicit Content-Length of zero is present), then the transport can re-use the
+// connection immediately. But when it re-uses the connection, it typically closes
+// the previous request's body, which is not optimal for zero-lengthed bodies,
+// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
+func TestZeroLengthPostAndResponse(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ all, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("handler ReadAll: %v", err)
+ }
+ if len(all) != 0 {
+ t.Errorf("handler got %d bytes; expected 0", len(all))
+ }
+ rw.Header().Set("Content-Length", "0")
+ }))
+ defer ts.Close()
+
+ req, err := NewRequest("POST", ts.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0
+
+ var resp [5]*Response
+ for i := range resp {
+ resp[i], err = DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("client post #%d: %v", i, err)
+ }
+ }
+
+ for i := range resp {
+ all, err := ioutil.ReadAll(resp[i].Body)
+ if err != nil {
+ t.Fatalf("req #%d: client ReadAll: %v", i, err)
+ }
+ if len(all) != 0 {
+ t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
+ }
+ }
+}
+
+func TestHandlerPanic(t *testing.T) {
+ testHandlerPanic(t, false)
+}
+
+func TestHandlerPanicWithHijack(t *testing.T) {
+ testHandlerPanic(t, true)
+}
+
+func testHandlerPanic(t *testing.T, withHijack bool) {
+ // Unlike the other tests that set the log output to ioutil.Discard
+ // to quiet the output, this test uses a pipe. The pipe serves three
+ // purposes:
+ //
+ // 1) The log.Print from the http server (generated by the caught
+ // panic) will go to the pipe instead of stderr, making the
+ // output quiet.
+ //
+ // 2) We read from the pipe to verify that the handler
+ // actually caught the panic and logged something.
+ //
+ // 3) The blocking Read call prevents this TestHandlerPanic
+ // function from exiting before the HTTP server handler
+ // finishes crashing. If this text function exited too
+ // early (and its defer log.SetOutput(os.Stderr) ran),
+ // then the crash output could spill into the next test.
+ pr, pw := io.Pipe()
+ log.SetOutput(pw)
+ defer log.SetOutput(os.Stderr)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if withHijack {
+ rwc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Logf("unexpected error: %v", err)
+ }
+ defer rwc.Close()
+ }
+ panic("intentional death for testing")
+ }))
+ defer ts.Close()
+
+ // Do a blocking read on the log output pipe so its logging
+ // doesn't bleed into the next test. But wait only 5 seconds
+ // for it.
+ done := make(chan bool, 1)
+ go func() {
+ buf := make([]byte, 4<<10)
+ _, err := pr.Read(buf)
+ pr.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ done <- true
+ }()
+
+ _, err := Get(ts.URL)
+ if err == nil {
+ t.Logf("expected an error")
+ }
+
+ select {
+ case <-done:
+ return
+ case <-time.After(5 * time.Second):
+ t.Fatal("expected server handler to log an error")
+ }
+}
+
+func TestNoDate(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()["Date"] = nil
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, present := res.Header["Date"]
+ if present {
+ t.Fatalf("Expected no Date header; got %v", res.Header["Date"])
+ }
+}
+
+func TestStripPrefix(t *testing.T) {
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Path", r.URL.Path)
+ })
+ ts := httptest.NewServer(StripPrefix("/foo", h))
+ defer ts.Close()
+
+ res, err := Get(ts.URL + "/foo/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := res.Header.Get("X-Path"), "/bar"; g != e {
+ t.Errorf("test 1: got %s, want %s", g, e)
+ }
+
+ res, err = Get(ts.URL + "/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := res.StatusCode, 404; g != e {
+ t.Errorf("test 2: got status %v, want %v", g, e)
+ }
+}
+
+func TestRequestLimit(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Fatalf("didn't expect to get request in Handler")
+ }))
+ defer ts.Close()
+ req, _ := NewRequest("GET", ts.URL, nil)
+ var bytesPerHeader = len("header12345: val12345\r\n")
+ for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
+ req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ // Some HTTP clients may fail on this undefined behavior (server replying and
+ // closing the connection while the request is still being written), but
+ // we do support it (at least currently), so we expect a response below.
+ t.Fatalf("Do: %v", err)
+ }
+ if res.StatusCode != 413 {
+ t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status)
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type countReader struct {
+ r io.Reader
+ n *int64
+}
+
+func (cr countReader) Read(p []byte) (n int, err error) {
+ n, err = cr.r.Read(p)
+ *cr.n += int64(n)
+ return
+}
+
+func TestRequestBodyLimit(t *testing.T) {
+ const limit = 1 << 20
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.Body = MaxBytesReader(w, r.Body, limit)
+ n, err := io.Copy(ioutil.Discard, r.Body)
+ if err == nil {
+ t.Errorf("expected error from io.Copy")
+ }
+ if n != limit {
+ t.Errorf("io.Copy = %d, want %d", n, limit)
+ }
+ }))
+ defer ts.Close()
+
+ nWritten := int64(0)
+ req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200))
+
+ // Send the POST, but don't care it succeeds or not. The
+ // remote side is going to reply and then close the TCP
+ // connection, and HTTP doesn't really define if that's
+ // allowed or not. Some HTTP clients will get the response
+ // and some (like ours, currently) will complain that the
+ // request write failed, without reading the response.
+ //
+ // But that's okay, since what we're really testing is that
+ // the remote side hung up on us before we wrote too much.
+ _, _ = DefaultClient.Do(req)
+
+ if nWritten > limit*100 {
+ t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
+ limit, nWritten)
+ }
+}
+
+// TestClientWriteShutdown tests that if the client shuts down the write
+// side of their TCP connection, the server doesn't send a 400 Bad Request.
+func TestClientWriteShutdown(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ err = conn.(*net.TCPConn).CloseWrite()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ donec := make(chan bool)
+ go func() {
+ defer close(donec)
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
+ got := string(bs)
+ if got != "" {
+ t.Errorf("read %q from server; want nothing", got)
+ }
+ }()
+ select {
+ case <-donec:
+ case <-time.After(10 * time.Second):
+ t.Fatalf("timeout")
+ }
+}
+
+// Tests that chunked server responses that write 1 byte at a time are
+// buffered before chunk headers are added, not after chunk headers.
+func TestServerBufferedChunking(t *testing.T) {
+ if true {
+ t.Logf("Skipping known broken test; see Issue 2357")
+ return
+ }
+ conn := new(testConn)
+ conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+ done := make(chan bool)
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
+ rw.Write([]byte{'x'})
+ rw.Write([]byte{'y'})
+ rw.Write([]byte{'z'})
+ }))
+ <-done
+ if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
+ t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
+ conn.writeBuf.Bytes())
+ }
+}
+
+// goTimeout runs f, failing t if f takes more than ns to complete.
+func goTimeout(t *testing.T, d time.Duration, f func()) {
+ ch := make(chan bool, 2)
+ timer := time.AfterFunc(d, func() {
+ t.Errorf("Timeout expired after %v", d)
+ ch <- true
+ })
+ defer timer.Stop()
+ go func() {
+ defer func() { ch <- true }()
+ f()
+ }()
+ <-ch
+}
+
+type errorListener struct {
+ errs []error
+}
+
+func (l *errorListener) Accept() (c net.Conn, err error) {
+ if len(l.errs) == 0 {
+ return nil, io.EOF
+ }
+ err = l.errs[0]
+ l.errs = l.errs[1:]
+ return
+}
+
+func (l *errorListener) Close() error {
+ return nil
+}
+
+func (l *errorListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func TestAcceptMaxFds(t *testing.T) {
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+
+ ln := &errorListener{[]error{
+ &net.OpError{
+ Op: "accept",
+ Err: syscall.EMFILE,
+ }}}
+ err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
+ if err != io.EOF {
+ t.Errorf("got error %v, want EOF", err)
+ }
+}
+
+func BenchmarkClientServer(b *testing.B) {
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ res, err := Get(ts.URL)
+ if err != nil {
+ b.Fatal("Get:", err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ b.Fatal("ReadAll:", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ b.Fatal("Got body:", body)
+ }
+ }
+
+ b.StopTimer()
+}
diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go
new file mode 100644
index 000000000..bad3bcb28
--- /dev/null
+++ b/src/pkg/net/http/server.go
@@ -0,0 +1,1191 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP server. See RFC 2616.
+
+// TODO(rsc):
+// logging
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/rand"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/url"
+ "path"
+ "runtime/debug"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// Errors introduced by the HTTP server.
+var (
+ ErrWriteAfterFlush = errors.New("Conn.Write called after Flush")
+ ErrBodyNotAllowed = errors.New("http: response status code does not allow body")
+ ErrHijacked = errors.New("Conn has been hijacked")
+ ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length")
+)
+
+// Objects implementing the Handler interface can be
+// registered to serve a particular path or subtree
+// in the HTTP server.
+//
+// ServeHTTP should write reply headers and data to the ResponseWriter
+// and then return. Returning signals that the request is finished
+// and that the HTTP server can move on to the next request on
+// the connection.
+type Handler interface {
+ ServeHTTP(ResponseWriter, *Request)
+}
+
+// A ResponseWriter interface is used by an HTTP handler to
+// construct an HTTP response.
+type ResponseWriter interface {
+ // Header returns the header map that will be sent by WriteHeader.
+ // Changing the header after a call to WriteHeader (or Write) has
+ // no effect.
+ Header() Header
+
+ // Write writes the data to the connection as part of an HTTP reply.
+ // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK)
+ // before writing the data.
+ Write([]byte) (int, error)
+
+ // WriteHeader sends an HTTP response header with status code.
+ // If WriteHeader is not called explicitly, the first call to Write
+ // will trigger an implicit WriteHeader(http.StatusOK).
+ // Thus explicit calls to WriteHeader are mainly used to
+ // send error codes.
+ WriteHeader(int)
+}
+
+// The Flusher interface is implemented by ResponseWriters that allow
+// an HTTP handler to flush buffered data to the client.
+//
+// Note that even for ResponseWriters that support Flush,
+// if the client is connected through an HTTP proxy,
+// the buffered data may not reach the client until the response
+// completes.
+type Flusher interface {
+ // Flush sends any buffered data to the client.
+ Flush()
+}
+
+// The Hijacker interface is implemented by ResponseWriters that allow
+// an HTTP handler to take over the connection.
+type Hijacker interface {
+ // Hijack lets the caller take over the connection.
+ // After a call to Hijack(), the HTTP server library
+ // will not do anything else with the connection.
+ // It becomes the caller's responsibility to manage
+ // and close the connection.
+ Hijack() (net.Conn, *bufio.ReadWriter, error)
+}
+
+// A conn represents the server side of an HTTP connection.
+type conn struct {
+ remoteAddr string // network address of remote side
+ server *Server // the Server on which the connection arrived
+ rwc net.Conn // i/o connection
+ lr *io.LimitedReader // io.LimitReader(rwc)
+ buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
+ hijacked bool // connection has been hijacked by handler
+ tlsState *tls.ConnectionState // or nil when not using TLS
+ body []byte
+}
+
+// A response represents the server side of an HTTP response.
+type response struct {
+ conn *conn
+ req *Request // request for this response
+ chunking bool // using chunked transfer encoding for reply body
+ wroteHeader bool // reply header has been written
+ wroteContinue bool // 100 Continue response was written
+ header Header // reply header parameters
+ written int64 // number of bytes written in body
+ contentLength int64 // explicitly-declared Content-Length; or -1
+ status int // status code passed to WriteHeader
+ needSniff bool // need to sniff to find Content-Type
+
+ // close connection after this reply. set on request and
+ // updated after response from handler if there's a
+ // "Connection: keep-alive" response header and a
+ // Content-Length.
+ closeAfterReply bool
+
+ // requestBodyLimitHit is set by requestTooLarge when
+ // maxBytesReader hits its max size. It is checked in
+ // WriteHeader, to make sure we don't consume the the
+ // remaining request body to try to advance to the next HTTP
+ // request. Instead, when this is set, we stop doing
+ // subsequent requests on this connection and stop reading
+ // input from it.
+ requestBodyLimitHit bool
+}
+
+// requestTooLarge is called by maxBytesReader when too much input has
+// been read from the client.
+func (w *response) requestTooLarge() {
+ w.closeAfterReply = true
+ w.requestBodyLimitHit = true
+ if !w.wroteHeader {
+ w.Header().Set("Connection", "close")
+ }
+}
+
+type writerOnly struct {
+ io.Writer
+}
+
+func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
+ // Call WriteHeader before checking w.chunking if it hasn't
+ // been called yet, since WriteHeader is what sets w.chunking.
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ if !w.chunking && w.bodyAllowed() && !w.needSniff {
+ w.Flush()
+ if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
+ n, err = rf.ReadFrom(src)
+ w.written += n
+ return
+ }
+ }
+ // Fall back to default io.Copy implementation.
+ // Use wrapper to hide w.ReadFrom from io.Copy.
+ return io.Copy(writerOnly{w}, src)
+}
+
+// noLimit is an effective infinite upper bound for io.LimitedReader
+const noLimit int64 = (1 << 63) - 1
+
+// Create new connection from rwc.
+func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
+ c = new(conn)
+ c.remoteAddr = rwc.RemoteAddr().String()
+ c.server = srv
+ c.rwc = rwc
+ c.body = make([]byte, sniffLen)
+ c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+ br := bufio.NewReader(c.lr)
+ bw := bufio.NewWriter(rwc)
+ c.buf = bufio.NewReadWriter(br, bw)
+ return c, nil
+}
+
+// DefaultMaxHeaderBytes is the maximum permitted size of the headers
+// in an HTTP request.
+// This can be overridden by setting Server.MaxHeaderBytes.
+const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
+
+func (srv *Server) maxHeaderBytes() int {
+ if srv.MaxHeaderBytes > 0 {
+ return srv.MaxHeaderBytes
+ }
+ return DefaultMaxHeaderBytes
+}
+
+// wrapper around io.ReaderCloser which on first read, sends an
+// HTTP/1.1 100 Continue header
+type expectContinueReader struct {
+ resp *response
+ readCloser io.ReadCloser
+ closed bool
+}
+
+func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
+ if ecr.closed {
+ return 0, errors.New("http: Read after Close on request Body")
+ }
+ if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
+ ecr.resp.wroteContinue = true
+ io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
+ ecr.resp.conn.buf.Flush()
+ }
+ return ecr.readCloser.Read(p)
+}
+
+func (ecr *expectContinueReader) Close() error {
+ ecr.closed = true
+ return ecr.readCloser.Close()
+}
+
+// TimeFormat is the time format to use with
+// time.Parse and time.Time.Format when parsing
+// or generating times in HTTP headers.
+// It is like time.RFC1123 but hard codes GMT as the time zone.
+const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
+
+var errTooLarge = errors.New("http: request too large")
+
+// Read next request from connection.
+func (c *conn) readRequest() (w *response, err error) {
+ if c.hijacked {
+ return nil, ErrHijacked
+ }
+ c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
+ var req *Request
+ if req, err = ReadRequest(c.buf.Reader); err != nil {
+ if c.lr.N == 0 {
+ return nil, errTooLarge
+ }
+ return nil, err
+ }
+ c.lr.N = noLimit
+
+ req.RemoteAddr = c.remoteAddr
+ req.TLS = c.tlsState
+
+ w = new(response)
+ w.conn = c
+ w.req = req
+ w.header = make(Header)
+ w.contentLength = -1
+ c.body = c.body[:0]
+ return w, nil
+}
+
+func (w *response) Header() Header {
+ return w.header
+}
+
+// maxPostHandlerReadBytes is the max number of Request.Body bytes not
+// consumed by a handler that the server will read from the client
+// in order to keep a connection alive. If there are more bytes than
+// this then the server to be paranoid instead sends a "Connection:
+// close" response.
+//
+// This number is approximately what a typical machine's TCP buffer
+// size is anyway. (if we have the bytes on the machine, we might as
+// well read them)
+const maxPostHandlerReadBytes = 256 << 10
+
+func (w *response) WriteHeader(code int) {
+ if w.conn.hijacked {
+ log.Print("http: response.WriteHeader on hijacked connection")
+ return
+ }
+ if w.wroteHeader {
+ log.Print("http: multiple response.WriteHeader calls")
+ return
+ }
+ w.wroteHeader = true
+ w.status = code
+
+ // Check for a explicit (and valid) Content-Length header.
+ var hasCL bool
+ var contentLength int64
+ if clenStr := w.header.Get("Content-Length"); clenStr != "" {
+ var err error
+ contentLength, err = strconv.ParseInt(clenStr, 10, 64)
+ if err == nil {
+ hasCL = true
+ } else {
+ log.Printf("http: invalid Content-Length of %q sent", clenStr)
+ w.header.Del("Content-Length")
+ }
+ }
+
+ if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
+ _, connectionHeaderSet := w.header["Connection"]
+ if !connectionHeaderSet {
+ w.header.Set("Connection", "keep-alive")
+ }
+ } else if !w.req.ProtoAtLeast(1, 1) {
+ // Client did not ask to keep connection alive.
+ w.closeAfterReply = true
+ }
+
+ if w.header.Get("Connection") == "close" {
+ w.closeAfterReply = true
+ }
+
+ // Per RFC 2616, we should consume the request body before
+ // replying, if the handler hasn't already done so. But we
+ // don't want to do an unbounded amount of reading here for
+ // DoS reasons, so we only try up to a threshold.
+ if w.req.ContentLength != 0 && !w.closeAfterReply {
+ ecr, isExpecter := w.req.Body.(*expectContinueReader)
+ if !isExpecter || ecr.resp.wroteContinue {
+ n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ if n >= maxPostHandlerReadBytes {
+ w.requestTooLarge()
+ w.header.Set("Connection", "close")
+ } else {
+ w.req.Body.Close()
+ }
+ }
+ }
+
+ if code == StatusNotModified {
+ // Must not have body.
+ for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
+ if w.header.Get(header) != "" {
+ // TODO: return an error if WriteHeader gets a return parameter
+ // or set a flag on w to make future Writes() write an error page?
+ // for now just log and drop the header.
+ log.Printf("http: StatusNotModified response with header %q defined", header)
+ w.header.Del(header)
+ }
+ }
+ } else {
+ // If no content type, apply sniffing algorithm to body.
+ if w.header.Get("Content-Type") == "" {
+ w.needSniff = true
+ }
+ }
+
+ if _, ok := w.header["Date"]; !ok {
+ w.Header().Set("Date", time.Now().UTC().Format(TimeFormat))
+ }
+
+ te := w.header.Get("Transfer-Encoding")
+ hasTE := te != ""
+ if hasCL && hasTE && te != "identity" {
+ // TODO: return an error if WriteHeader gets a return parameter
+ // For now just ignore the Content-Length.
+ log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
+ te, contentLength)
+ w.header.Del("Content-Length")
+ hasCL = false
+ }
+
+ if w.req.Method == "HEAD" || code == StatusNotModified {
+ // do nothing
+ } else if hasCL {
+ w.contentLength = contentLength
+ w.header.Del("Transfer-Encoding")
+ } else if w.req.ProtoAtLeast(1, 1) {
+ // HTTP/1.1 or greater: use chunked transfer encoding
+ // to avoid closing the connection at EOF.
+ // TODO: this blows away any custom or stacked Transfer-Encoding they
+ // might have set. Deal with that as need arises once we have a valid
+ // use case.
+ w.chunking = true
+ w.header.Set("Transfer-Encoding", "chunked")
+ } else {
+ // HTTP version < 1.1: cannot do chunked transfer
+ // encoding and we don't know the Content-Length so
+ // signal EOF by closing connection.
+ w.closeAfterReply = true
+ w.header.Del("Transfer-Encoding") // in case already set
+ }
+
+ // Cannot use Content-Length with non-identity Transfer-Encoding.
+ if w.chunking {
+ w.header.Del("Content-Length")
+ }
+ if !w.req.ProtoAtLeast(1, 0) {
+ return
+ }
+ proto := "HTTP/1.0"
+ if w.req.ProtoAtLeast(1, 1) {
+ proto = "HTTP/1.1"
+ }
+ codestring := strconv.Itoa(code)
+ text, ok := statusText[code]
+ if !ok {
+ text = "status code " + codestring
+ }
+ io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n")
+ w.header.Write(w.conn.buf)
+
+ // If we need to sniff the body, leave the header open.
+ // Otherwise, end it here.
+ if !w.needSniff {
+ io.WriteString(w.conn.buf, "\r\n")
+ }
+}
+
+// sniff uses the first block of written data,
+// stored in w.conn.body, to decide the Content-Type
+// for the HTTP body.
+func (w *response) sniff() {
+ if !w.needSniff {
+ return
+ }
+ w.needSniff = false
+
+ data := w.conn.body
+ fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data))
+
+ if len(data) == 0 {
+ return
+ }
+ if w.chunking {
+ fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
+ }
+ _, err := w.conn.buf.Write(data)
+ if w.chunking && err == nil {
+ io.WriteString(w.conn.buf, "\r\n")
+ }
+}
+
+// bodyAllowed returns true if a Write is allowed for this response type.
+// It's illegal to call this before the header has been flushed.
+func (w *response) bodyAllowed() bool {
+ if !w.wroteHeader {
+ panic("")
+ }
+ return w.status != StatusNotModified && w.req.Method != "HEAD"
+}
+
+func (w *response) Write(data []byte) (n int, err error) {
+ if w.conn.hijacked {
+ log.Print("http: response.Write on hijacked connection")
+ return 0, ErrHijacked
+ }
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ if len(data) == 0 {
+ return 0, nil
+ }
+ if !w.bodyAllowed() {
+ return 0, ErrBodyNotAllowed
+ }
+
+ w.written += int64(len(data)) // ignoring errors, for errorKludge
+ if w.contentLength != -1 && w.written > w.contentLength {
+ return 0, ErrContentLength
+ }
+
+ var m int
+ if w.needSniff {
+ // We need to sniff the beginning of the output to
+ // determine the content type. Accumulate the
+ // initial writes in w.conn.body.
+ // Cap m so that append won't allocate.
+ m = cap(w.conn.body) - len(w.conn.body)
+ if m > len(data) {
+ m = len(data)
+ }
+ w.conn.body = append(w.conn.body, data[:m]...)
+ data = data[m:]
+ if len(data) == 0 {
+ // Copied everything into the buffer.
+ // Wait for next write.
+ return m, nil
+ }
+
+ // Filled the buffer; more data remains.
+ // Sniff the content (flushes the buffer)
+ // and then proceed with the remainder
+ // of the data as a normal Write.
+ // Calling sniff clears needSniff.
+ w.sniff()
+ }
+
+ // TODO(rsc): if chunking happened after the buffering,
+ // then there would be fewer chunk headers.
+ // On the other hand, it would make hijacking more difficult.
+ if w.chunking {
+ fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
+ }
+ n, err = w.conn.buf.Write(data)
+ if err == nil && w.chunking {
+ if n != len(data) {
+ err = io.ErrShortWrite
+ }
+ if err == nil {
+ io.WriteString(w.conn.buf, "\r\n")
+ }
+ }
+
+ return m + n, err
+}
+
+func (w *response) finishRequest() {
+ // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
+ // back, we can make this a keep-alive response ...
+ if w.req.wantsHttp10KeepAlive() {
+ sentLength := w.header.Get("Content-Length") != ""
+ if sentLength && w.header.Get("Connection") == "keep-alive" {
+ w.closeAfterReply = false
+ }
+ }
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ if w.needSniff {
+ w.sniff()
+ }
+ if w.chunking {
+ io.WriteString(w.conn.buf, "0\r\n")
+ // trailer key/value pairs, followed by blank line
+ io.WriteString(w.conn.buf, "\r\n")
+ }
+ w.conn.buf.Flush()
+ // Close the body, unless we're about to close the whole TCP connection
+ // anyway.
+ if !w.closeAfterReply {
+ w.req.Body.Close()
+ }
+ if w.req.MultipartForm != nil {
+ w.req.MultipartForm.RemoveAll()
+ }
+
+ if w.contentLength != -1 && w.contentLength != w.written {
+ // Did not write enough. Avoid getting out of sync.
+ w.closeAfterReply = true
+ }
+}
+
+func (w *response) Flush() {
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ w.sniff()
+ w.conn.buf.Flush()
+}
+
+// Close the connection.
+func (c *conn) close() {
+ if c.buf != nil {
+ c.buf.Flush()
+ c.buf = nil
+ }
+ if c.rwc != nil {
+ c.rwc.Close()
+ c.rwc = nil
+ }
+}
+
+// Serve a new connection.
+func (c *conn) serve() {
+ defer func() {
+ err := recover()
+ if err == nil {
+ return
+ }
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err)
+ buf.Write(debug.Stack())
+ log.Print(buf.String())
+
+ if c.rwc != nil { // may be nil if connection hijacked
+ c.rwc.Close()
+ }
+ }()
+
+ if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+ if err := tlsConn.Handshake(); err != nil {
+ c.close()
+ return
+ }
+ c.tlsState = new(tls.ConnectionState)
+ *c.tlsState = tlsConn.ConnectionState()
+ }
+
+ for {
+ w, err := c.readRequest()
+ if err != nil {
+ msg := "400 Bad Request"
+ if err == errTooLarge {
+ // Their HTTP client may or may not be
+ // able to read this if we're
+ // responding to them and hanging up
+ // while they're still writing their
+ // request. Undefined behavior.
+ msg = "413 Request Entity Too Large"
+ } else if err == io.ErrUnexpectedEOF {
+ break // Don't reply
+ } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ break // Don't reply
+ }
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg)
+ break
+ }
+
+ // Expect 100 Continue support
+ req := w.req
+ if req.expectsContinue() {
+ if req.ProtoAtLeast(1, 1) {
+ // Wrap the Body reader with one that replies on the connection
+ req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
+ }
+ if req.ContentLength == 0 {
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusBadRequest)
+ w.finishRequest()
+ break
+ }
+ req.Header.Del("Expect")
+ } else if req.Header.Get("Expect") != "" {
+ // TODO(bradfitz): let ServeHTTP handlers handle
+ // requests with non-standard expectation[s]? Seems
+ // theoretical at best, and doesn't fit into the
+ // current ServeHTTP model anyway. We'd need to
+ // make the ResponseWriter an optional
+ // "ExpectReplier" interface or something.
+ //
+ // For now we'll just obey RFC 2616 14.20 which says
+ // "If a server receives a request containing an
+ // Expect field that includes an expectation-
+ // extension that it does not support, it MUST
+ // respond with a 417 (Expectation Failed) status."
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusExpectationFailed)
+ w.finishRequest()
+ break
+ }
+
+ handler := c.server.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+
+ // HTTP cannot have multiple simultaneous active requests.[*]
+ // Until the server replies to this request, it can't read another,
+ // so we might as well run the handler in this goroutine.
+ // [*] Not strictly true: HTTP pipelining. We could let them all process
+ // in parallel even if their responses need to be serialized.
+ handler.ServeHTTP(w, w.req)
+ if c.hijacked {
+ return
+ }
+ w.finishRequest()
+ if w.closeAfterReply {
+ break
+ }
+ }
+ c.close()
+}
+
+// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
+// and a Hijacker.
+func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ if w.conn.hijacked {
+ return nil, nil, ErrHijacked
+ }
+ w.conn.hijacked = true
+ rwc = w.conn.rwc
+ buf = w.conn.buf
+ w.conn.rwc = nil
+ w.conn.buf = nil
+ return
+}
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as HTTP handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler object that calls f.
+type HandlerFunc func(ResponseWriter, *Request)
+
+// ServeHTTP calls f(w, r).
+func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
+ f(w, r)
+}
+
+// Helper handlers
+
+// Error replies to the request with the specified error message and HTTP code.
+func Error(w ResponseWriter, error string, code int) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(code)
+ fmt.Fprintln(w, error)
+}
+
+// NotFound replies to the request with an HTTP 404 not found error.
+func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) }
+
+// NotFoundHandler returns a simple request handler
+// that replies to each request with a ``404 page not found'' reply.
+func NotFoundHandler() Handler { return HandlerFunc(NotFound) }
+
+// StripPrefix returns a handler that serves HTTP requests
+// by removing the given prefix from the request URL's Path
+// and invoking the handler h. StripPrefix handles a
+// request for a path that doesn't begin with prefix by
+// replying with an HTTP 404 not found error.
+func StripPrefix(prefix string, h Handler) Handler {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ if !strings.HasPrefix(r.URL.Path, prefix) {
+ NotFound(w, r)
+ return
+ }
+ r.URL.Path = r.URL.Path[len(prefix):]
+ h.ServeHTTP(w, r)
+ })
+}
+
+// Redirect replies to the request with a redirect to url,
+// which may be a path relative to the request path.
+func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
+ if u, err := url.Parse(urlStr); err == nil {
+ // If url was relative, make absolute by
+ // combining with request path.
+ // The browser would probably do this for us,
+ // but doing it ourselves is more reliable.
+
+ // NOTE(rsc): RFC 2616 says that the Location
+ // line must be an absolute URI, like
+ // "http://www.google.com/redirect/",
+ // not a path like "/redirect/".
+ // Unfortunately, we don't know what to
+ // put in the host name section to get the
+ // client to connect to us again, so we can't
+ // know the right absolute URI to send back.
+ // Because of this problem, no one pays attention
+ // to the RFC; they all send back just a new path.
+ // So do we.
+ oldpath := r.URL.Path
+ if oldpath == "" { // should not happen, but avoid a crash if it does
+ oldpath = "/"
+ }
+ if u.Scheme == "" {
+ // no leading http://server
+ if urlStr == "" || urlStr[0] != '/' {
+ // make relative path absolute
+ olddir, _ := path.Split(oldpath)
+ urlStr = olddir + urlStr
+ }
+
+ var query string
+ if i := strings.Index(urlStr, "?"); i != -1 {
+ urlStr, query = urlStr[:i], urlStr[i:]
+ }
+
+ // clean up but preserve trailing slash
+ trailing := urlStr[len(urlStr)-1] == '/'
+ urlStr = path.Clean(urlStr)
+ if trailing && urlStr[len(urlStr)-1] != '/' {
+ urlStr += "/"
+ }
+ urlStr += query
+ }
+ }
+
+ w.Header().Set("Location", urlStr)
+ w.WriteHeader(code)
+
+ // RFC2616 recommends that a short note "SHOULD" be included in the
+ // response because older user agents may not understand 301/307.
+ // Shouldn't send the response for POST or HEAD; that leaves GET.
+ if r.Method == "GET" {
+ note := "<a href=\"" + htmlEscape(urlStr) + "\">" + statusText[code] + "</a>.\n"
+ fmt.Fprintln(w, note)
+ }
+}
+
+var htmlReplacer = strings.NewReplacer(
+ "&", "&amp;",
+ "<", "&lt;",
+ ">", "&gt;",
+ `"`, "&quot;",
+ "'", "&apos;",
+)
+
+func htmlEscape(s string) string {
+ return htmlReplacer.Replace(s)
+}
+
+// Redirect to a fixed URL
+type redirectHandler struct {
+ url string
+ code int
+}
+
+func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ Redirect(w, r, rh.url, rh.code)
+}
+
+// RedirectHandler returns a request handler that redirects
+// each request it receives to the given url using the given
+// status code.
+func RedirectHandler(url string, code int) Handler {
+ return &redirectHandler{url, code}
+}
+
+// ServeMux is an HTTP request multiplexer.
+// It matches the URL of each incoming request against a list of registered
+// patterns and calls the handler for the pattern that
+// most closely matches the URL.
+//
+// Patterns named fixed, rooted paths, like "/favicon.ico",
+// or rooted subtrees, like "/images/" (note the trailing slash).
+// Longer patterns take precedence over shorter ones, so that
+// if there are handlers registered for both "/images/"
+// and "/images/thumbnails/", the latter handler will be
+// called for paths beginning "/images/thumbnails/" and the
+// former will receiver requests for any other paths in the
+// "/images/" subtree.
+//
+// Patterns may optionally begin with a host name, restricting matches to
+// URLs on that host only. Host-specific patterns take precedence over
+// general patterns, so that a handler might register for the two patterns
+// "/codesearch" and "codesearch.google.com/" without also taking over
+// requests for "http://www.google.com/".
+//
+// ServeMux also takes care of sanitizing the URL request path,
+// redirecting any request containing . or .. elements to an
+// equivalent .- and ..-free URL.
+type ServeMux struct {
+ m map[string]Handler
+}
+
+// NewServeMux allocates and returns a new ServeMux.
+func NewServeMux() *ServeMux { return &ServeMux{make(map[string]Handler)} }
+
+// DefaultServeMux is the default ServeMux used by Serve.
+var DefaultServeMux = NewServeMux()
+
+// Does path match pattern?
+func pathMatch(pattern, path string) bool {
+ if len(pattern) == 0 {
+ // should not happen
+ return false
+ }
+ n := len(pattern)
+ if pattern[n-1] != '/' {
+ return pattern == path
+ }
+ return len(path) >= n && path[0:n] == pattern
+}
+
+// Return the canonical path for p, eliminating . and .. elements.
+func cleanPath(p string) string {
+ if p == "" {
+ return "/"
+ }
+ if p[0] != '/' {
+ p = "/" + p
+ }
+ np := path.Clean(p)
+ // path.Clean removes trailing slash except for root;
+ // put the trailing slash back if necessary.
+ if p[len(p)-1] == '/' && np != "/" {
+ np += "/"
+ }
+ return np
+}
+
+// Find a handler on a handler map given a path string
+// Most-specific (longest) pattern wins
+func (mux *ServeMux) match(path string) Handler {
+ var h Handler
+ var n = 0
+ for k, v := range mux.m {
+ if !pathMatch(k, path) {
+ continue
+ }
+ if h == nil || len(k) > n {
+ n = len(k)
+ h = v
+ }
+ }
+ return h
+}
+
+// ServeHTTP dispatches the request to the handler whose
+// pattern most closely matches the request URL.
+func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
+ // Clean path to canonical form and redirect.
+ if p := cleanPath(r.URL.Path); p != r.URL.Path {
+ w.Header().Set("Location", p)
+ w.WriteHeader(StatusMovedPermanently)
+ return
+ }
+ // Host-specific pattern takes precedence over generic ones
+ h := mux.match(r.Host + r.URL.Path)
+ if h == nil {
+ h = mux.match(r.URL.Path)
+ }
+ if h == nil {
+ h = NotFoundHandler()
+ }
+ h.ServeHTTP(w, r)
+}
+
+// Handle registers the handler for the given pattern.
+func (mux *ServeMux) Handle(pattern string, handler Handler) {
+ if pattern == "" {
+ panic("http: invalid pattern " + pattern)
+ }
+
+ mux.m[pattern] = handler
+
+ // Helpful behavior:
+ // If pattern is /tree/, insert permanent redirect for /tree.
+ n := len(pattern)
+ if n > 0 && pattern[n-1] == '/' {
+ mux.m[pattern[0:n-1]] = RedirectHandler(pattern, StatusMovedPermanently)
+ }
+}
+
+// HandleFunc registers the handler function for the given pattern.
+func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ mux.Handle(pattern, HandlerFunc(handler))
+}
+
+// Handle registers the handler for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
+
+// HandleFunc registers the handler function for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ DefaultServeMux.HandleFunc(pattern, handler)
+}
+
+// Serve accepts incoming HTTP connections on the listener l,
+// creating a new service thread for each. The service threads
+// read requests and then call handler to reply to them.
+// Handler is typically nil, in which case the DefaultServeMux is used.
+func Serve(l net.Listener, handler Handler) error {
+ srv := &Server{Handler: handler}
+ return srv.Serve(l)
+}
+
+// A Server defines parameters for running an HTTP server.
+type Server struct {
+ Addr string // TCP address to listen on, ":http" if empty
+ Handler Handler // handler to invoke, http.DefaultServeMux if nil
+ ReadTimeout time.Duration // maximum duration before timing out read of the request
+ WriteTimeout time.Duration // maximum duration before timing out write of the response
+ MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
+}
+
+// ListenAndServe listens on the TCP network address srv.Addr and then
+// calls Serve to handle requests on incoming connections. If
+// srv.Addr is blank, ":http" is used.
+func (srv *Server) ListenAndServe() error {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":http"
+ }
+ l, e := net.Listen("tcp", addr)
+ if e != nil {
+ return e
+ }
+ return srv.Serve(l)
+}
+
+// Serve accepts incoming connections on the Listener l, creating a
+// new service thread for each. The service threads read requests and
+// then call srv.Handler to reply to them.
+func (srv *Server) Serve(l net.Listener) error {
+ defer l.Close()
+ for {
+ rw, e := l.Accept()
+ if e != nil {
+ if ne, ok := e.(net.Error); ok && ne.Temporary() {
+ log.Printf("http: Accept error: %v", e)
+ continue
+ }
+ return e
+ }
+ if srv.ReadTimeout != 0 {
+ rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+ }
+ if srv.WriteTimeout != 0 {
+ rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
+ }
+ c, err := srv.newConn(rw)
+ if err != nil {
+ continue
+ }
+ go c.serve()
+ }
+ panic("not reached")
+}
+
+// ListenAndServe listens on the TCP network address addr
+// and then calls Serve with handler to handle requests
+// on incoming connections. Handler is typically nil,
+// in which case the DefaultServeMux is used.
+//
+// A trivial example server is:
+//
+// package main
+//
+// import (
+// "io"
+// "net/http"
+// "log"
+// )
+//
+// // hello world, the web server
+// func HelloServer(w http.ResponseWriter, req *http.Request) {
+// io.WriteString(w, "hello, world!\n")
+// }
+//
+// func main() {
+// http.HandleFunc("/hello", HelloServer)
+// err := http.ListenAndServe(":12345", nil)
+// if err != nil {
+// log.Fatal("ListenAndServe: ", err)
+// }
+// }
+func ListenAndServe(addr string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServe()
+}
+
+// ListenAndServeTLS acts identically to ListenAndServe, except that it
+// expects HTTPS connections. Additionally, files containing a certificate and
+// matching private key for the server must be provided. If the certificate
+// is signed by a certificate authority, the certFile should be the concatenation
+// of the server's certificate followed by the CA's certificate.
+//
+// A trivial example server is:
+//
+// import (
+// "log"
+// "net/http"
+// )
+//
+// func handler(w http.ResponseWriter, req *http.Request) {
+// w.Header().Set("Content-Type", "text/plain")
+// w.Write([]byte("This is an example server.\n"))
+// }
+//
+// func main() {
+// http.HandleFunc("/", handler)
+// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/")
+// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil)
+// if err != nil {
+// log.Fatal(err)
+// }
+// }
+//
+// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
+func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServeTLS(certFile, keyFile)
+}
+
+// ListenAndServeTLS listens on the TCP network address srv.Addr and
+// then calls Serve to handle requests on incoming TLS connections.
+//
+// Filenames containing a certificate and matching private key for
+// the server must be provided. If the certificate is signed by a
+// certificate authority, the certFile should be the concatenation
+// of the server's certificate followed by the CA's certificate.
+//
+// If srv.Addr is blank, ":https" is used.
+func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
+ addr := s.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+ config := &tls.Config{
+ Rand: rand.Reader,
+ NextProtos: []string{"http/1.1"},
+ }
+
+ var err error
+ config.Certificates = make([]tls.Certificate, 1)
+ config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+
+ conn, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ tlsListener := tls.NewListener(conn, config)
+ return s.Serve(tlsListener)
+}
+
+// TimeoutHandler returns a Handler that runs h with the given time limit.
+//
+// The new Handler calls h.ServeHTTP to handle each request, but if a
+// call runs for more than ns nanoseconds, the handler responds with
+// a 503 Service Unavailable error and the given message in its body.
+// (If msg is empty, a suitable default message will be sent.)
+// After such a timeout, writes by h to its ResponseWriter will return
+// ErrHandlerTimeout.
+func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
+ f := func() <-chan time.Time {
+ return time.After(dt)
+ }
+ return &timeoutHandler{h, f, msg}
+}
+
+// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// in handlers which have timed out.
+var ErrHandlerTimeout = errors.New("http: Handler timeout")
+
+type timeoutHandler struct {
+ handler Handler
+ timeout func() <-chan time.Time // returns channel producing a timeout
+ body string
+}
+
+func (h *timeoutHandler) errorBody() string {
+ if h.body != "" {
+ return h.body
+ }
+ return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
+}
+
+func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ done := make(chan bool)
+ tw := &timeoutWriter{w: w}
+ go func() {
+ h.handler.ServeHTTP(tw, r)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-h.timeout():
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ if !tw.wroteHeader {
+ tw.w.WriteHeader(StatusServiceUnavailable)
+ tw.w.Write([]byte(h.errorBody()))
+ }
+ tw.timedOut = true
+ }
+}
+
+type timeoutWriter struct {
+ w ResponseWriter
+
+ mu sync.Mutex
+ timedOut bool
+ wroteHeader bool
+}
+
+func (tw *timeoutWriter) Header() Header {
+ return tw.w.Header()
+}
+
+func (tw *timeoutWriter) Write(p []byte) (int, error) {
+ tw.mu.Lock()
+ timedOut := tw.timedOut
+ tw.mu.Unlock()
+ if timedOut {
+ return 0, ErrHandlerTimeout
+ }
+ return tw.w.Write(p)
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+ tw.mu.Lock()
+ if tw.timedOut || tw.wroteHeader {
+ tw.mu.Unlock()
+ return
+ }
+ tw.wroteHeader = true
+ tw.mu.Unlock()
+ tw.w.WriteHeader(code)
+}
diff --git a/src/pkg/net/http/sniff.go b/src/pkg/net/http/sniff.go
new file mode 100644
index 000000000..c1c78e241
--- /dev/null
+++ b/src/pkg/net/http/sniff.go
@@ -0,0 +1,214 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "encoding/binary"
+)
+
+// Content-type sniffing algorithm.
+// References in this file refer to this draft specification:
+// http://mimesniff.spec.whatwg.org/
+
+// The algorithm prefers to use sniffLen bytes to make its decision.
+const sniffLen = 512
+
+// DetectContentType returns the sniffed Content-Type string
+// for the given data. This function always returns a valid MIME type.
+func DetectContentType(data []byte) string {
+ if len(data) > sniffLen {
+ data = data[:sniffLen]
+ }
+
+ // Index of the first non-whitespace byte in data.
+ firstNonWS := 0
+ for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ {
+ }
+
+ for _, sig := range sniffSignatures {
+ if ct := sig.match(data, firstNonWS); ct != "" {
+ return ct
+ }
+ }
+
+ return "application/octet-stream" // fallback
+}
+
+func isWS(b byte) bool {
+ return bytes.IndexByte([]byte("\t\n\x0C\r "), b) != -1
+}
+
+type sniffSig interface {
+ // match returns the MIME type of the data, or "" if unknown.
+ match(data []byte, firstNonWS int) string
+}
+
+// Data matching the table in section 6.
+var sniffSignatures = []sniffSig{
+ htmlSig("<!DOCTYPE HTML"),
+ htmlSig("<HTML"),
+ htmlSig("<HEAD"),
+ htmlSig("<SCRIPT"),
+ htmlSig("<IFRAME"),
+ htmlSig("<H1"),
+ htmlSig("<DIV"),
+ htmlSig("<FONT"),
+ htmlSig("<TABLE"),
+ htmlSig("<A"),
+ htmlSig("<STYLE"),
+ htmlSig("<TITLE"),
+ htmlSig("<B"),
+ htmlSig("<BODY"),
+ htmlSig("<BR"),
+ htmlSig("<P"),
+ htmlSig("<!--"),
+
+ &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"},
+
+ &exactSig{[]byte("%PDF-"), "application/pdf"},
+ &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
+
+ // UTF BOMs.
+ &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"},
+ &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"},
+ &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"},
+
+ &exactSig{[]byte("GIF87a"), "image/gif"},
+ &exactSig{[]byte("GIF89a"), "image/gif"},
+ &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"},
+ &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
+ &exactSig{[]byte("BM"), "image/bmp"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
+ ct: "image/webp",
+ },
+ &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"},
+ &exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
+ ct: "audio/wave",
+ },
+ &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
+ &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"},
+ &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
+ &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
+
+ // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
+ //mp4Sig(0),
+
+ textSig(0), // should be last
+}
+
+type exactSig struct {
+ sig []byte
+ ct string
+}
+
+func (e *exactSig) match(data []byte, firstNonWS int) string {
+ if bytes.HasPrefix(data, e.sig) {
+ return e.ct
+ }
+ return ""
+}
+
+type maskedSig struct {
+ mask, pat []byte
+ skipWS bool
+ ct string
+}
+
+func (m *maskedSig) match(data []byte, firstNonWS int) string {
+ if m.skipWS {
+ data = data[firstNonWS:]
+ }
+ if len(data) < len(m.mask) {
+ return ""
+ }
+ for i, mask := range m.mask {
+ db := data[i] & mask
+ if db != m.pat[i] {
+ return ""
+ }
+ }
+ return m.ct
+}
+
+type htmlSig []byte
+
+func (h htmlSig) match(data []byte, firstNonWS int) string {
+ data = data[firstNonWS:]
+ if len(data) < len(h)+1 {
+ return ""
+ }
+ for i, b := range h {
+ db := data[i]
+ if 'A' <= b && b <= 'Z' {
+ db &= 0xDF
+ }
+ if b != db {
+ return ""
+ }
+ }
+ // Next byte must be space or right angle bracket.
+ if db := data[len(h)]; db != ' ' && db != '>' {
+ return ""
+ }
+ return "text/html; charset=utf-8"
+}
+
+type mp4Sig int
+
+func (mp4Sig) match(data []byte, firstNonWS int) string {
+ // c.f. section 6.1.
+ if len(data) < 8 {
+ return ""
+ }
+ boxSize := int(binary.BigEndian.Uint32(data[:4]))
+ if boxSize%4 != 0 || len(data) < boxSize {
+ return ""
+ }
+ if !bytes.Equal(data[4:8], []byte("ftyp")) {
+ return ""
+ }
+ for st := 8; st < boxSize; st += 4 {
+ if st == 12 {
+ // minor version number
+ continue
+ }
+ seg := string(data[st : st+3])
+ switch seg {
+ case "mp4", "iso", "M4V", "M4P", "M4B":
+ return "video/mp4"
+ /* The remainder are not in the spec.
+ case "M4A":
+ return "audio/mp4"
+ case "3gp":
+ return "video/3gpp"
+ case "jp2":
+ return "image/jp2" // JPEG 2000
+ */
+ }
+ }
+ return ""
+}
+
+type textSig int
+
+func (textSig) match(data []byte, firstNonWS int) string {
+ // c.f. section 5, step 4.
+ for _, b := range data[firstNonWS:] {
+ switch {
+ case 0x00 <= b && b <= 0x08,
+ b == 0x0B,
+ 0x0E <= b && b <= 0x1A,
+ 0x1C <= b && b <= 0x1F:
+ return ""
+ }
+ }
+ return "text/plain; charset=utf-8"
+}
diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go
new file mode 100644
index 000000000..6efa8ce1c
--- /dev/null
+++ b/src/pkg/net/http/sniff_test.go
@@ -0,0 +1,137 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ . "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+var sniffTests = []struct {
+ desc string
+ data []byte
+ contentType string
+}{
+ // Some nonsense.
+ {"Empty", []byte{}, "text/plain; charset=utf-8"},
+ {"Binary", []byte{1, 2, 3}, "application/octet-stream"},
+
+ {"HTML document #1", []byte(`<HtMl><bOdY>blah blah blah</body></html>`), "text/html; charset=utf-8"},
+ {"HTML document #2", []byte(`<HTML></HTML>`), "text/html; charset=utf-8"},
+ {"HTML document #3 (leading whitespace)", []byte(` <!DOCTYPE HTML>...`), "text/html; charset=utf-8"},
+ {"HTML document #4 (leading CRLF)", []byte("\r\n<html>..."), "text/html; charset=utf-8"},
+
+ {"Plain text", []byte(`This is not HTML. It has ☃ though.`), "text/plain; charset=utf-8"},
+
+ {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"},
+
+ // Image types.
+ {"GIF 87a", []byte(`GIF87a`), "image/gif"},
+ {"GIF 89a", []byte(`GIF89a...`), "image/gif"},
+
+ // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
+ //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"},
+ //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"},
+}
+
+func TestDetectContentType(t *testing.T) {
+ for _, tt := range sniffTests {
+ ct := DetectContentType(tt.data)
+ if ct != tt.contentType {
+ t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ }
+}
+
+func TestServerContentType(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ i, _ := strconv.Atoi(r.FormValue("i"))
+ tt := sniffTests[i]
+ n, err := w.Write(tt.data)
+ if n != len(tt.data) || err != nil {
+ log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
+ }
+ }))
+ defer ts.Close()
+
+ for i, tt := range sniffTests {
+ resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
+ if err != nil {
+ t.Errorf("%v: %v", tt.desc, err)
+ continue
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != tt.contentType {
+ t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%v: reading body: %v", tt.desc, err)
+ } else if !bytes.Equal(data, tt.data) {
+ t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
+ }
+ resp.Body.Close()
+ }
+}
+
+func TestContentTypeWithCopy(t *testing.T) {
+ const (
+ input = "\n<html>\n\t<head>\n"
+ expected = "text/html; charset=utf-8"
+ )
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
+ buf := bytes.NewBuffer([]byte(input))
+ n, err := io.Copy(w, buf)
+ if int(n) != len(input) || err != nil {
+ t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ }))
+ defer ts.Close()
+
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != expected {
+ t.Errorf("Content-Type = %q, want %q", ct, expected)
+ }
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("reading body: %v", err)
+ } else if !bytes.Equal(data, []byte(input)) {
+ t.Errorf("data is %q, want %q", data, input)
+ }
+ resp.Body.Close()
+}
+
+func TestSniffWriteSize(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ size, _ := strconv.Atoi(r.FormValue("size"))
+ written, err := io.WriteString(w, strings.Repeat("a", size))
+ if err != nil {
+ t.Errorf("write of %d bytes: %v", size, err)
+ return
+ }
+ if written != size {
+ t.Errorf("write of %d bytes wrote %d bytes", size, written)
+ }
+ }))
+ defer ts.Close()
+ for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
+ _, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
+ if err != nil {
+ t.Fatalf("size %d: %v", size, err)
+ }
+ }
+}
diff --git a/src/pkg/net/http/status.go b/src/pkg/net/http/status.go
new file mode 100644
index 000000000..b6e2d65c6
--- /dev/null
+++ b/src/pkg/net/http/status.go
@@ -0,0 +1,106 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// HTTP status codes, defined in RFC 2616.
+const (
+ StatusContinue = 100
+ StatusSwitchingProtocols = 101
+
+ StatusOK = 200
+ StatusCreated = 201
+ StatusAccepted = 202
+ StatusNonAuthoritativeInfo = 203
+ StatusNoContent = 204
+ StatusResetContent = 205
+ StatusPartialContent = 206
+
+ StatusMultipleChoices = 300
+ StatusMovedPermanently = 301
+ StatusFound = 302
+ StatusSeeOther = 303
+ StatusNotModified = 304
+ StatusUseProxy = 305
+ StatusTemporaryRedirect = 307
+
+ StatusBadRequest = 400
+ StatusUnauthorized = 401
+ StatusPaymentRequired = 402
+ StatusForbidden = 403
+ StatusNotFound = 404
+ StatusMethodNotAllowed = 405
+ StatusNotAcceptable = 406
+ StatusProxyAuthRequired = 407
+ StatusRequestTimeout = 408
+ StatusConflict = 409
+ StatusGone = 410
+ StatusLengthRequired = 411
+ StatusPreconditionFailed = 412
+ StatusRequestEntityTooLarge = 413
+ StatusRequestURITooLong = 414
+ StatusUnsupportedMediaType = 415
+ StatusRequestedRangeNotSatisfiable = 416
+ StatusExpectationFailed = 417
+
+ StatusInternalServerError = 500
+ StatusNotImplemented = 501
+ StatusBadGateway = 502
+ StatusServiceUnavailable = 503
+ StatusGatewayTimeout = 504
+ StatusHTTPVersionNotSupported = 505
+)
+
+var statusText = map[int]string{
+ StatusContinue: "Continue",
+ StatusSwitchingProtocols: "Switching Protocols",
+
+ StatusOK: "OK",
+ StatusCreated: "Created",
+ StatusAccepted: "Accepted",
+ StatusNonAuthoritativeInfo: "Non-Authoritative Information",
+ StatusNoContent: "No Content",
+ StatusResetContent: "Reset Content",
+ StatusPartialContent: "Partial Content",
+
+ StatusMultipleChoices: "Multiple Choices",
+ StatusMovedPermanently: "Moved Permanently",
+ StatusFound: "Found",
+ StatusSeeOther: "See Other",
+ StatusNotModified: "Not Modified",
+ StatusUseProxy: "Use Proxy",
+ StatusTemporaryRedirect: "Temporary Redirect",
+
+ StatusBadRequest: "Bad Request",
+ StatusUnauthorized: "Unauthorized",
+ StatusPaymentRequired: "Payment Required",
+ StatusForbidden: "Forbidden",
+ StatusNotFound: "Not Found",
+ StatusMethodNotAllowed: "Method Not Allowed",
+ StatusNotAcceptable: "Not Acceptable",
+ StatusProxyAuthRequired: "Proxy Authentication Required",
+ StatusRequestTimeout: "Request Timeout",
+ StatusConflict: "Conflict",
+ StatusGone: "Gone",
+ StatusLengthRequired: "Length Required",
+ StatusPreconditionFailed: "Precondition Failed",
+ StatusRequestEntityTooLarge: "Request Entity Too Large",
+ StatusRequestURITooLong: "Request URI Too Long",
+ StatusUnsupportedMediaType: "Unsupported Media Type",
+ StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
+ StatusExpectationFailed: "Expectation Failed",
+
+ StatusInternalServerError: "Internal Server Error",
+ StatusNotImplemented: "Not Implemented",
+ StatusBadGateway: "Bad Gateway",
+ StatusServiceUnavailable: "Service Unavailable",
+ StatusGatewayTimeout: "Gateway Timeout",
+ StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
+}
+
+// StatusText returns a text for the HTTP status code. It returns the empty
+// string if the code is unknown.
+func StatusText(code int) string {
+ return statusText[code]
+}
diff --git a/src/pkg/net/http/testdata/file b/src/pkg/net/http/testdata/file
new file mode 100644
index 000000000..11f11f9be
--- /dev/null
+++ b/src/pkg/net/http/testdata/file
@@ -0,0 +1 @@
+0123456789
diff --git a/src/pkg/net/http/testdata/index.html b/src/pkg/net/http/testdata/index.html
new file mode 100644
index 000000000..da8e1e93d
--- /dev/null
+++ b/src/pkg/net/http/testdata/index.html
@@ -0,0 +1 @@
+index.html says hello
diff --git a/src/pkg/net/http/testdata/style.css b/src/pkg/net/http/testdata/style.css
new file mode 100644
index 000000000..208d16d42
--- /dev/null
+++ b/src/pkg/net/http/testdata/style.css
@@ -0,0 +1 @@
+body {}
diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go
new file mode 100644
index 000000000..ef9564af9
--- /dev/null
+++ b/src/pkg/net/http/transfer.go
@@ -0,0 +1,630 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/textproto"
+ "strconv"
+ "strings"
+)
+
+// transferWriter inspects the fields of a user-supplied Request or Response,
+// sanitizes them without changing the user object and provides methods for
+// writing the respective header, body and trailer in wire format.
+type transferWriter struct {
+ Method string
+ Body io.Reader
+ BodyCloser io.Closer
+ ResponseToHEAD bool
+ ContentLength int64 // -1 means unknown, 0 means exactly none
+ Close bool
+ TransferEncoding []string
+ Trailer Header
+}
+
+func newTransferWriter(r interface{}) (t *transferWriter, err error) {
+ t = &transferWriter{}
+
+ // Extract relevant fields
+ atLeastHTTP11 := false
+ switch rr := r.(type) {
+ case *Request:
+ if rr.ContentLength != 0 && rr.Body == nil {
+ return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
+ }
+ t.Method = rr.Method
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.ContentLength
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Trailer = rr.Trailer
+ atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
+ if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 {
+ if t.ContentLength == 0 {
+ // Test to see if it's actually zero or just unset.
+ var buf [1]byte
+ n, _ := io.ReadFull(t.Body, buf[:])
+ if n == 1 {
+ // Oh, guess there is data in this Body Reader after all.
+ // The ContentLength field just wasn't set.
+ // Stich the Body back together again, re-attaching our
+ // consumed byte.
+ t.ContentLength = -1
+ t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body)
+ } else {
+ // Body is actually empty.
+ t.Body = nil
+ t.BodyCloser = nil
+ }
+ }
+ if t.ContentLength < 0 {
+ t.TransferEncoding = []string{"chunked"}
+ }
+ }
+ case *Response:
+ t.Method = rr.Request.Method
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.ContentLength
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Trailer = rr.Trailer
+ atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
+ t.ResponseToHEAD = noBodyExpected(rr.Request.Method)
+ }
+
+ // Sanitize Body,ContentLength,TransferEncoding
+ if t.ResponseToHEAD {
+ t.Body = nil
+ t.TransferEncoding = nil
+ // ContentLength is expected to hold Content-Length
+ if t.ContentLength < 0 {
+ return nil, ErrMissingContentLength
+ }
+ } else {
+ if !atLeastHTTP11 || t.Body == nil {
+ t.TransferEncoding = nil
+ }
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
+ } else if t.Body == nil { // no chunking, no body
+ t.ContentLength = 0
+ }
+ }
+
+ // Sanitize Trailer
+ if !chunked(t.TransferEncoding) {
+ t.Trailer = nil
+ }
+
+ return t, nil
+}
+
+func noBodyExpected(requestMethod string) bool {
+ return requestMethod == "HEAD"
+}
+
+func (t *transferWriter) shouldSendContentLength() bool {
+ if chunked(t.TransferEncoding) {
+ return false
+ }
+ if t.ContentLength > 0 {
+ return true
+ }
+ if t.ResponseToHEAD {
+ return true
+ }
+ // Many servers expect a Content-Length for these methods
+ if t.Method == "POST" || t.Method == "PUT" {
+ return true
+ }
+ if t.ContentLength == 0 && isIdentity(t.TransferEncoding) {
+ return true
+ }
+
+ return false
+}
+
+func (t *transferWriter) WriteHeader(w io.Writer) (err error) {
+ if t.Close {
+ _, err = io.WriteString(w, "Connection: close\r\n")
+ if err != nil {
+ return
+ }
+ }
+
+ // Write Content-Length and/or Transfer-Encoding whose values are a
+ // function of the sanitized field triple (Body, ContentLength,
+ // TransferEncoding)
+ if t.shouldSendContentLength() {
+ io.WriteString(w, "Content-Length: ")
+ _, err = io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n")
+ if err != nil {
+ return
+ }
+ } else if chunked(t.TransferEncoding) {
+ _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n")
+ if err != nil {
+ return
+ }
+ }
+
+ // Write Trailer header
+ if t.Trailer != nil {
+ // TODO: At some point, there should be a generic mechanism for
+ // writing long headers, using HTTP line splitting
+ io.WriteString(w, "Trailer: ")
+ needComma := false
+ for k := range t.Trailer {
+ k = CanonicalHeaderKey(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return &badStringError{"invalid Trailer key", k}
+ }
+ if needComma {
+ io.WriteString(w, ",")
+ }
+ io.WriteString(w, k)
+ needComma = true
+ }
+ _, err = io.WriteString(w, "\r\n")
+ }
+
+ return
+}
+
+func (t *transferWriter) WriteBody(w io.Writer) (err error) {
+ var ncopy int64
+
+ // Write body
+ if t.Body != nil {
+ if chunked(t.TransferEncoding) {
+ cw := newChunkedWriter(w)
+ _, err = io.Copy(cw, t.Body)
+ if err == nil {
+ err = cw.Close()
+ }
+ } else if t.ContentLength == -1 {
+ ncopy, err = io.Copy(w, t.Body)
+ } else {
+ ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
+ nextra, err := io.Copy(ioutil.Discard, t.Body)
+ if err != nil {
+ return err
+ }
+ ncopy += nextra
+ }
+ if err != nil {
+ return err
+ }
+ if err = t.BodyCloser.Close(); err != nil {
+ return err
+ }
+ }
+
+ if t.ContentLength != -1 && t.ContentLength != ncopy {
+ return fmt.Errorf("http: Request.ContentLength=%d with Body length %d",
+ t.ContentLength, ncopy)
+ }
+
+ // TODO(petar): Place trailer writer code here.
+ if chunked(t.TransferEncoding) {
+ // Last chunk, empty trailer
+ _, err = io.WriteString(w, "\r\n")
+ }
+
+ return
+}
+
+type transferReader struct {
+ // Input
+ Header Header
+ StatusCode int
+ RequestMethod string
+ ProtoMajor int
+ ProtoMinor int
+ // Output
+ Body io.ReadCloser
+ ContentLength int64
+ TransferEncoding []string
+ Close bool
+ Trailer Header
+}
+
+// bodyAllowedForStatus returns whether a given response status code
+// permits a body. See RFC2616, section 4.4.
+func bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+// msg is *Request or *Response.
+func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
+ t := &transferReader{}
+
+ // Unify input
+ isResponse := false
+ switch rr := msg.(type) {
+ case *Response:
+ t.Header = rr.Header
+ t.StatusCode = rr.StatusCode
+ t.RequestMethod = rr.Request.Method
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header)
+ isResponse = true
+ case *Request:
+ t.Header = rr.Header
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ // Transfer semantics for Requests are exactly like those for
+ // Responses with status code 200, responding to a GET method
+ t.StatusCode = 200
+ t.RequestMethod = "GET"
+ default:
+ panic("unexpected type")
+ }
+
+ // Default to HTTP/1.1
+ if t.ProtoMajor == 0 && t.ProtoMinor == 0 {
+ t.ProtoMajor, t.ProtoMinor = 1, 1
+ }
+
+ // Transfer encoding, content length
+ t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header)
+ if err != nil {
+ return err
+ }
+
+ t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+ if err != nil {
+ return err
+ }
+
+ // Trailer
+ t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
+ if err != nil {
+ return err
+ }
+
+ // If there is no Content-Length or chunked Transfer-Encoding on a *Response
+ // and the status is not 1xx, 204 or 304, then the body is unbounded.
+ // See RFC2616, section 4.4.
+ switch msg.(type) {
+ case *Response:
+ if t.ContentLength == -1 &&
+ !chunked(t.TransferEncoding) &&
+ bodyAllowedForStatus(t.StatusCode) {
+ // Unbounded body.
+ t.Close = true
+ }
+ }
+
+ // Prepare body reader. ContentLength < 0 means chunked encoding
+ // or close connection when finished, since multipart is not supported yet
+ switch {
+ case chunked(t.TransferEncoding):
+ t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ case t.ContentLength >= 0:
+ // TODO: limit the Content-Length. This is an easy DoS vector.
+ t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close}
+ default:
+ // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header
+ if t.Close {
+ // Close semantics (i.e. HTTP/1.0)
+ t.Body = &body{Reader: r, closing: t.Close}
+ } else {
+ // Persistent connection (i.e. HTTP/1.1)
+ t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close}
+ }
+ }
+
+ // Unify output
+ switch rr := msg.(type) {
+ case *Request:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ rr.TransferEncoding = t.TransferEncoding
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ case *Response:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ rr.TransferEncoding = t.TransferEncoding
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ }
+
+ return nil
+}
+
+// Checks whether chunked is part of the encodings stack
+func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
+
+// Checks whether the encoding is explicitly "identity".
+func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
+
+// Sanitize transfer encoding
+func fixTransferEncoding(requestMethod string, header Header) ([]string, error) {
+ raw, present := header["Transfer-Encoding"]
+ if !present {
+ return nil, nil
+ }
+
+ delete(header, "Transfer-Encoding")
+
+ // Head responses have no bodies, so the transfer encoding
+ // should be ignored.
+ if requestMethod == "HEAD" {
+ return nil, nil
+ }
+
+ encodings := strings.Split(raw[0], ",")
+ te := make([]string, 0, len(encodings))
+ // TODO: Even though we only support "identity" and "chunked"
+ // encodings, the loop below is designed with foresight. One
+ // invariant that must be maintained is that, if present,
+ // chunked encoding must always come first.
+ for _, encoding := range encodings {
+ encoding = strings.ToLower(strings.TrimSpace(encoding))
+ // "identity" encoding is not recored
+ if encoding == "identity" {
+ break
+ }
+ if encoding != "chunked" {
+ return nil, &badStringError{"unsupported transfer encoding", encoding}
+ }
+ te = te[0 : len(te)+1]
+ te[len(te)-1] = encoding
+ }
+ if len(te) > 1 {
+ return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")}
+ }
+ if len(te) > 0 {
+ // Chunked encoding trumps Content-Length. See RFC 2616
+ // Section 4.4. Currently len(te) > 0 implies chunked
+ // encoding.
+ delete(header, "Content-Length")
+ return te, nil
+ }
+
+ return nil, nil
+}
+
+// Determine the expected body length, using RFC 2616 Section 4.4. This
+// function is not a method, because ultimately it should be shared by
+// ReadResponse and ReadRequest.
+func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) {
+
+ // Logic based on response type or status
+ if noBodyExpected(requestMethod) {
+ return 0, nil
+ }
+ if status/100 == 1 {
+ return 0, nil
+ }
+ switch status {
+ case 204, 304:
+ return 0, nil
+ }
+
+ // Logic based on Transfer-Encoding
+ if chunked(te) {
+ return -1, nil
+ }
+
+ // Logic based on Content-Length
+ cl := strings.TrimSpace(header.Get("Content-Length"))
+ if cl != "" {
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil || n < 0 {
+ return -1, &badStringError{"bad Content-Length", cl}
+ }
+ return n, nil
+ } else {
+ header.Del("Content-Length")
+ }
+
+ if !isResponse && requestMethod == "GET" {
+ // RFC 2616 doesn't explicitly permit nor forbid an
+ // entity-body on a GET request so we permit one if
+ // declared, but we default to 0 here (not -1 below)
+ // if there's no mention of a body.
+ return 0, nil
+ }
+
+ // Logic based on media type. The purpose of the following code is just
+ // to detect whether the unsupported "multipart/byteranges" is being
+ // used. A proper Content-Type parser is needed in the future.
+ if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") {
+ return -1, ErrNotSupported
+ }
+
+ // Body-EOF logic based on other methods (like closing, or chunked coding)
+ return -1, nil
+}
+
+// Determine whether to hang up after sending a request and body, or
+// receiving a response and body
+// 'header' is the request headers
+func shouldClose(major, minor int, header Header) bool {
+ if major < 1 {
+ return true
+ } else if major == 1 && minor == 0 {
+ if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") {
+ return true
+ }
+ return false
+ } else {
+ // TODO: Should split on commas, toss surrounding white space,
+ // and check each field.
+ if strings.ToLower(header.Get("Connection")) == "close" {
+ header.Del("Connection")
+ return true
+ }
+ }
+ return false
+}
+
+// Parse the trailer header
+func fixTrailer(header Header, te []string) (Header, error) {
+ raw := header.Get("Trailer")
+ if raw == "" {
+ return nil, nil
+ }
+
+ header.Del("Trailer")
+ trailer := make(Header)
+ keys := strings.Split(raw, ",")
+ for _, key := range keys {
+ key = CanonicalHeaderKey(strings.TrimSpace(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return nil, &badStringError{"bad trailer key", key}
+ }
+ trailer.Del(key)
+ }
+ if len(trailer) == 0 {
+ return nil, nil
+ }
+ if !chunked(te) {
+ // Trailer and no chunking
+ return nil, ErrUnexpectedTrailer
+ }
+ return trailer, nil
+}
+
+// body turns a Reader into a ReadCloser.
+// Close ensures that the body has been fully read
+// and then reads the trailer if necessary.
+type body struct {
+ io.Reader
+ hdr interface{} // non-nil (Response or Request) value means read trailer
+ r *bufio.Reader // underlying wire-format reader for the trailer
+ closing bool // is the connection to be closed after reading body?
+ closed bool
+
+ res *response // response writer for server requests, else nil
+}
+
+// ErrBodyReadAfterClose is returned when reading a Request Body after
+// the body has been closed. This typically happens when the body is
+// read after an HTTP Handler calls WriteHeader or Write on its
+// ResponseWriter.
+var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body")
+
+func (b *body) Read(p []byte) (n int, err error) {
+ if b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ n, err = b.Reader.Read(p)
+
+ // Read the final trailer once we hit EOF.
+ if err == io.EOF && b.hdr != nil {
+ if e := b.readTrailer(); e != nil {
+ err = e
+ }
+ b.hdr = nil
+ }
+ return n, err
+}
+
+var (
+ singleCRLF = []byte("\r\n")
+ doubleCRLF = []byte("\r\n\r\n")
+)
+
+func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
+ for peekSize := 4; ; peekSize++ {
+ // This loop stops when Peek returns an error,
+ // which it does when r's buffer has been filled.
+ buf, err := r.Peek(peekSize)
+ if bytes.HasSuffix(buf, doubleCRLF) {
+ return true
+ }
+ if err != nil {
+ break
+ }
+ }
+ return false
+}
+
+func (b *body) readTrailer() error {
+ // The common case, since nobody uses trailers.
+ buf, _ := b.r.Peek(2)
+ if bytes.Equal(buf, singleCRLF) {
+ b.r.ReadByte()
+ b.r.ReadByte()
+ return nil
+ }
+
+ // Make sure there's a header terminator coming up, to prevent
+ // a DoS with an unbounded size Trailer. It's not easy to
+ // slip in a LimitReader here, as textproto.NewReader requires
+ // a concrete *bufio.Reader. Also, we can't get all the way
+ // back up to our conn's LimitedReader that *might* be backing
+ // this bufio.Reader. Instead, a hack: we iteratively Peek up
+ // to the bufio.Reader's max size, looking for a double CRLF.
+ // This limits the trailer to the underlying buffer size, typically 4kB.
+ if !seeUpcomingDoubleCRLF(b.r) {
+ return errors.New("http: suspiciously long trailer after chunked body")
+ }
+
+ hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
+ if err != nil {
+ return err
+ }
+ switch rr := b.hdr.(type) {
+ case *Request:
+ rr.Trailer = Header(hdr)
+ case *Response:
+ rr.Trailer = Header(hdr)
+ }
+ return nil
+}
+
+func (b *body) Close() error {
+ if b.closed {
+ return nil
+ }
+ defer func() {
+ b.closed = true
+ }()
+ if b.hdr == nil && b.closing {
+ // no trailer and closing the connection next.
+ // no point in reading to EOF.
+ return nil
+ }
+
+ // In a server request, don't continue reading from the client
+ // if we've already hit the maximum body size set by the
+ // handler. If this is set, that also means the TCP connection
+ // is about to be closed, so getting to the next HTTP request
+ // in the stream is not necessary.
+ if b.res != nil && b.res.requestBodyLimitHit {
+ return nil
+ }
+
+ // Fully consume the body, which will also lead to us reading
+ // the trailer headers after the body, if present.
+ if _, err := io.Copy(ioutil.Discard, b); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go
new file mode 100644
index 000000000..4de070f01
--- /dev/null
+++ b/src/pkg/net/http/transport.go
@@ -0,0 +1,736 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client implementation. See RFC 2616.
+//
+// This is the low-level Transport implementation of RoundTripper.
+// The high-level interface is in client.go.
+
+package http
+
+import (
+ "bufio"
+ "compress/gzip"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+)
+
+// DefaultTransport is the default implementation of Transport and is
+// used by DefaultClient. It establishes a new network connection for
+// each call to Do and uses HTTP proxies as directed by the
+// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
+// environment variables.
+var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
+
+// DefaultMaxIdleConnsPerHost is the default value of Transport's
+// MaxIdleConnsPerHost.
+const DefaultMaxIdleConnsPerHost = 2
+
+// Transport is an implementation of RoundTripper that supports http,
+// https, and http proxies (for either http or https with CONNECT).
+// Transport can also cache connections for future re-use.
+type Transport struct {
+ lk sync.Mutex
+ idleConn map[string][]*persistConn
+ altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
+
+ // TODO: tunable on global max cached connections
+ // TODO: tunable on timeout on cached connections
+ // TODO: optional pipelining
+
+ // Proxy specifies a function to return a proxy for a given
+ // Request. If the function returns a non-nil error, the
+ // request is aborted with the provided error.
+ // If Proxy is nil or returns a nil *URL, no proxy is used.
+ Proxy func(*Request) (*url.URL, error)
+
+ // Dial specifies the dial function for creating TCP
+ // connections.
+ // If Dial is nil, net.Dial is used.
+ Dial func(net, addr string) (c net.Conn, err error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client. If nil, the default configuration is used.
+ TLSClientConfig *tls.Config
+
+ DisableKeepAlives bool
+ DisableCompression bool
+
+ // MaxIdleConnsPerHost, if non-zero, controls the maximum idle
+ // (keep-alive) to keep to keep per-host. If zero,
+ // DefaultMaxIdleConnsPerHost is used.
+ MaxIdleConnsPerHost int
+}
+
+// ProxyFromEnvironment returns the URL of the proxy to use for a
+// given request, as indicated by the environment variables
+// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy).
+// Either URL or an error is returned.
+func ProxyFromEnvironment(req *Request) (*url.URL, error) {
+ proxy := getenvEitherCase("HTTP_PROXY")
+ if proxy == "" {
+ return nil, nil
+ }
+ if !useProxy(canonicalAddr(req.URL)) {
+ return nil, nil
+ }
+ proxyURL, err := url.ParseRequest(proxy)
+ if err != nil {
+ return nil, errors.New("invalid proxy address")
+ }
+ if proxyURL.Host == "" {
+ proxyURL, err = url.ParseRequest("http://" + proxy)
+ if err != nil {
+ return nil, errors.New("invalid proxy address")
+ }
+ }
+ return proxyURL, nil
+}
+
+// ProxyURL returns a proxy function (for use in a Transport)
+// that always returns the same URL.
+func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
+ return func(*Request) (*url.URL, error) {
+ return fixedURL, nil
+ }
+}
+
+// transportRequest is a wrapper around a *Request that adds
+// optional extra headers to write.
+type transportRequest struct {
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+}
+
+func (tr *transportRequest) extraHeaders() Header {
+ if tr.extra == nil {
+ tr.extra = make(Header)
+ }
+ return tr.extra
+}
+
+// RoundTrip implements the RoundTripper interface.
+func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
+ if req.URL == nil {
+ return nil, errors.New("http: nil Request.URL")
+ }
+ if req.Header == nil {
+ return nil, errors.New("http: nil Request.Header")
+ }
+ if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
+ t.lk.Lock()
+ var rt RoundTripper
+ if t.altProto != nil {
+ rt = t.altProto[req.URL.Scheme]
+ }
+ t.lk.Unlock()
+ if rt == nil {
+ return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+ }
+ return rt.RoundTrip(req)
+ }
+ treq := &transportRequest{Request: req}
+ cm, err := t.connectMethodForRequest(treq)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the cached or newly-created connection to either the
+ // host (for http or https), the http proxy, or the http proxy
+ // pre-CONNECTed to https server. In any case, we'll be ready
+ // to send it requests.
+ pconn, err := t.getConn(cm)
+ if err != nil {
+ return nil, err
+ }
+
+ return pconn.roundTrip(treq)
+}
+
+// RegisterProtocol registers a new protocol with scheme.
+// The Transport will pass requests using the given scheme to rt.
+// It is rt's responsibility to simulate HTTP request semantics.
+//
+// RegisterProtocol can be used by other packages to provide
+// implementations of protocol schemes like "ftp" or "file".
+func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
+ if scheme == "http" || scheme == "https" {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.altProto == nil {
+ t.altProto = make(map[string]RoundTripper)
+ }
+ if _, exists := t.altProto[scheme]; exists {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.altProto[scheme] = rt
+}
+
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle in
+// a "keep-alive" state. It does not interrupt any connections currently
+// in use.
+func (t *Transport) CloseIdleConnections() {
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.idleConn == nil {
+ return
+ }
+ for _, conns := range t.idleConn {
+ for _, pconn := range conns {
+ pconn.close()
+ }
+ }
+ t.idleConn = nil
+}
+
+//
+// Private implementation past this point.
+//
+
+func getenvEitherCase(k string) string {
+ if v := os.Getenv(strings.ToUpper(k)); v != "" {
+ return v
+ }
+ return os.Getenv(strings.ToLower(k))
+}
+
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, error) {
+ cm := &connectMethod{
+ targetScheme: treq.URL.Scheme,
+ targetAddr: canonicalAddr(treq.URL),
+ }
+ if t.Proxy != nil {
+ var err error
+ cm.proxyURL, err = t.Proxy(treq.Request)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return cm, nil
+}
+
+// proxyAuth returns the Proxy-Authorization header to set
+// on requests, if applicable.
+func (cm *connectMethod) proxyAuth() string {
+ if cm.proxyURL == nil {
+ return ""
+ }
+ if u := cm.proxyURL.User; u != nil {
+ return "Basic " + base64.URLEncoding.EncodeToString([]byte(u.String()))
+ }
+ return ""
+}
+
+func (t *Transport) putIdleConn(pconn *persistConn) {
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
+ pconn.close()
+ return
+ }
+ if pconn.isBroken() {
+ return
+ }
+ key := pconn.cacheKey
+ max := t.MaxIdleConnsPerHost
+ if max == 0 {
+ max = DefaultMaxIdleConnsPerHost
+ }
+ if len(t.idleConn[key]) >= max {
+ pconn.close()
+ return
+ }
+ t.idleConn[key] = append(t.idleConn[key], pconn)
+}
+
+func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.idleConn == nil {
+ t.idleConn = make(map[string][]*persistConn)
+ }
+ key := cm.String()
+ for {
+ pconns, ok := t.idleConn[key]
+ if !ok {
+ return nil
+ }
+ if len(pconns) == 1 {
+ pconn = pconns[0]
+ delete(t.idleConn, key)
+ } else {
+ // 2 or more cached connections; pop last
+ // TODO: queue?
+ pconn = pconns[len(pconns)-1]
+ t.idleConn[key] = pconns[0 : len(pconns)-1]
+ }
+ if !pconn.isBroken() {
+ return
+ }
+ }
+ return
+}
+
+func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
+ if t.Dial != nil {
+ return t.Dial(network, addr)
+ }
+ return net.Dial(network, addr)
+}
+
+// getConn dials and creates a new persistConn to the target as
+// specified in the connectMethod. This includes doing a proxy CONNECT
+// and/or setting up TLS. If this doesn't return an error, the persistConn
+// is ready to write requests to.
+func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
+ if pc := t.getIdleConn(cm); pc != nil {
+ return pc, nil
+ }
+
+ conn, err := t.dial("tcp", cm.addr())
+ if err != nil {
+ if cm.proxyURL != nil {
+ err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
+ }
+ return nil, err
+ }
+
+ pa := cm.proxyAuth()
+
+ pconn := &persistConn{
+ t: t,
+ cacheKey: cm.String(),
+ conn: conn,
+ reqch: make(chan requestAndChan, 50),
+ }
+
+ switch {
+ case cm.proxyURL == nil:
+ // Do nothing.
+ case cm.targetScheme == "http":
+ pconn.isProxy = true
+ if pa != "" {
+ pconn.mutateHeaderFunc = func(h Header) {
+ h.Set("Proxy-Authorization", pa)
+ }
+ }
+ case cm.targetScheme == "https":
+ connectReq := &Request{
+ Method: "CONNECT",
+ URL: &url.URL{Opaque: cm.targetAddr},
+ Host: cm.targetAddr,
+ Header: make(Header),
+ }
+ if pa != "" {
+ connectReq.Header.Set("Proxy-Authorization", pa)
+ }
+ connectReq.Write(conn)
+
+ // Read response.
+ // Okay to use and discard buffered reader here, because
+ // TLS server will not speak until spoken to.
+ br := bufio.NewReader(conn)
+ resp, err := ReadResponse(br, connectReq)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+ if resp.StatusCode != 200 {
+ f := strings.SplitN(resp.Status, " ", 2)
+ conn.Close()
+ return nil, errors.New(f[1])
+ }
+ }
+
+ if cm.targetScheme == "https" {
+ // Initiate TLS and check remote host name against certificate.
+ conn = tls.Client(conn, t.TLSClientConfig)
+ if err = conn.(*tls.Conn).Handshake(); err != nil {
+ return nil, err
+ }
+ if t.TLSClientConfig == nil || !t.TLSClientConfig.InsecureSkipVerify {
+ if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil {
+ return nil, err
+ }
+ }
+ pconn.conn = conn
+ }
+
+ pconn.br = bufio.NewReader(pconn.conn)
+ pconn.bw = bufio.NewWriter(pconn.conn)
+ go pconn.readLoop()
+ return pconn, nil
+}
+
+// useProxy returns true if requests to addr should use a proxy,
+// according to the NO_PROXY or no_proxy environment variable.
+// addr is always a canonicalAddr with a host and port.
+func useProxy(addr string) bool {
+ if len(addr) == 0 {
+ return true
+ }
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return false
+ }
+ if host == "localhost" {
+ return false
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if ip.IsLoopback() {
+ return false
+ }
+ }
+
+ no_proxy := getenvEitherCase("NO_PROXY")
+ if no_proxy == "*" {
+ return false
+ }
+
+ addr = strings.ToLower(strings.TrimSpace(addr))
+ if hasPort(addr) {
+ addr = addr[:strings.LastIndex(addr, ":")]
+ }
+
+ for _, p := range strings.Split(no_proxy, ",") {
+ p = strings.ToLower(strings.TrimSpace(p))
+ if len(p) == 0 {
+ continue
+ }
+ if hasPort(p) {
+ p = p[:strings.LastIndex(p, ":")]
+ }
+ if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
+ return false
+ }
+ }
+ return true
+}
+
+// connectMethod is the map key (in its String form) for keeping persistent
+// TCP connections alive for subsequent HTTP requests.
+//
+// A connect method may be of the following types:
+//
+// Cache key form Description
+// ----------------- -------------------------
+// ||http|foo.com http directly to server, no proxy
+// ||https|foo.com https directly to server, no proxy
+// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
+// http://proxy.com|http http to proxy, http to anywhere after that
+//
+// Note: no support to https to the proxy yet.
+//
+type connectMethod struct {
+ proxyURL *url.URL // nil for no proxy, else full proxy URL
+ targetScheme string // "http" or "https"
+ targetAddr string // Not used if proxy + http targetScheme (4th example in table)
+}
+
+func (ck *connectMethod) String() string {
+ proxyStr := ""
+ if ck.proxyURL != nil {
+ proxyStr = ck.proxyURL.String()
+ }
+ return strings.Join([]string{proxyStr, ck.targetScheme, ck.targetAddr}, "|")
+}
+
+// addr returns the first hop "host:port" to which we need to TCP connect.
+func (cm *connectMethod) addr() string {
+ if cm.proxyURL != nil {
+ return canonicalAddr(cm.proxyURL)
+ }
+ return cm.targetAddr
+}
+
+// tlsHost returns the host name to match against the peer's
+// TLS certificate.
+func (cm *connectMethod) tlsHost() string {
+ h := cm.targetAddr
+ if hasPort(h) {
+ h = h[:strings.LastIndex(h, ":")]
+ }
+ return h
+}
+
+// persistConn wraps a connection, usually a persistent one
+// (but may be used for non-keep-alive requests as well)
+type persistConn struct {
+ t *Transport
+ cacheKey string // its connectMethod.String()
+ conn net.Conn
+ br *bufio.Reader // from conn
+ bw *bufio.Writer // to conn
+ reqch chan requestAndChan // written by roundTrip(); read by readLoop()
+ isProxy bool
+
+ // mutateHeaderFunc is an optional func to modify extra
+ // headers on each outbound request before it's written. (the
+ // original Request given to RoundTrip is not modified)
+ mutateHeaderFunc func(Header)
+
+ lk sync.Mutex // guards numExpectedResponses and broken
+ numExpectedResponses int
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
+}
+
+func (pc *persistConn) isBroken() bool {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ return pc.broken
+}
+
+var remoteSideClosedFunc func(error) bool // or nil to use default
+
+func remoteSideClosed(err error) bool {
+ if err == io.EOF {
+ return true
+ }
+ if remoteSideClosedFunc != nil {
+ return remoteSideClosedFunc(err)
+ }
+ return false
+}
+
+func (pc *persistConn) readLoop() {
+ alive := true
+ var lastbody io.ReadCloser // last response body, if any, read on this connection
+
+ for alive {
+ pb, err := pc.br.Peek(1)
+
+ pc.lk.Lock()
+ if pc.numExpectedResponses == 0 {
+ pc.closeLocked()
+ pc.lk.Unlock()
+ if len(pb) > 0 {
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
+ string(pb), err)
+ }
+ return
+ }
+ pc.lk.Unlock()
+
+ rc := <-pc.reqch
+
+ // Advance past the previous response's body, if the
+ // caller hasn't done so.
+ if lastbody != nil {
+ lastbody.Close() // assumed idempotent
+ lastbody = nil
+ }
+ resp, err := ReadResponse(pc.br, rc.req)
+
+ if err == nil {
+ hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
+ if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" {
+ resp.Header.Del("Content-Encoding")
+ resp.Header.Del("Content-Length")
+ resp.ContentLength = -1
+ gzReader, zerr := gzip.NewReader(resp.Body)
+ if zerr != nil {
+ pc.close()
+ err = zerr
+ } else {
+ resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body}
+ }
+ }
+ resp.Body = &bodyEOFSignal{body: resp.Body}
+ }
+
+ if err != nil || resp.Close || rc.req.Close {
+ alive = false
+ }
+
+ hasBody := resp != nil && resp.ContentLength != 0
+ var waitForBodyRead chan bool
+ if alive {
+ if hasBody {
+ lastbody = resp.Body
+ waitForBodyRead = make(chan bool)
+ resp.Body.(*bodyEOFSignal).fn = func() {
+ pc.t.putIdleConn(pc)
+ waitForBodyRead <- true
+ }
+ } else {
+ // When there's no response body, we immediately
+ // reuse the TCP connection (putIdleConn), but
+ // we need to prevent ClientConn.Read from
+ // closing the Response.Body on the next
+ // loop, otherwise it might close the body
+ // before the client code has had a chance to
+ // read it (even though it'll just be 0, EOF).
+ lastbody = nil
+
+ pc.t.putIdleConn(pc)
+ }
+ }
+
+ rc.ch <- responseAndError{resp, err}
+
+ // Wait for the just-returned response body to be fully consumed
+ // before we race and peek on the underlying bufio reader.
+ if waitForBodyRead != nil {
+ <-waitForBodyRead
+ }
+ }
+}
+
+type responseAndError struct {
+ res *Response
+ err error
+}
+
+type requestAndChan struct {
+ req *Request
+ ch chan responseAndError
+
+ // did the Transport (as opposed to the client code) add an
+ // Accept-Encoding gzip header? only if it we set it do
+ // we transparently decode the gzip.
+ addedGzip bool
+}
+
+func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
+ if pc.mutateHeaderFunc != nil {
+ pc.mutateHeaderFunc(req.extraHeaders())
+ }
+
+ // Ask for a compressed version if the caller didn't set their
+ // own value for Accept-Encoding. We only attempted to
+ // uncompress the gzip stream if we were the layer that
+ // requested it.
+ requestedGzip := false
+ if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: http://www.gzip.org/zlib/zlib_faq.html#faq38
+ requestedGzip = true
+ req.extraHeaders().Set("Accept-Encoding", "gzip")
+ }
+
+ pc.lk.Lock()
+ pc.numExpectedResponses++
+ pc.lk.Unlock()
+
+ err = req.Request.write(pc.bw, pc.isProxy, req.extra)
+ if err != nil {
+ pc.close()
+ return
+ }
+ pc.bw.Flush()
+
+ ch := make(chan responseAndError, 1)
+ pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
+ re := <-ch
+ pc.lk.Lock()
+ pc.numExpectedResponses--
+ pc.lk.Unlock()
+
+ return re.res, re.err
+}
+
+func (pc *persistConn) close() {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ pc.closeLocked()
+}
+
+func (pc *persistConn) closeLocked() {
+ pc.broken = true
+ pc.conn.Close()
+ pc.mutateHeaderFunc = nil
+}
+
+var portMap = map[string]string{
+ "http": "80",
+ "https": "443",
+}
+
+// canonicalAddr returns url.Host but always with a ":port" suffix
+func canonicalAddr(url *url.URL) string {
+ addr := url.Host
+ if !hasPort(addr) {
+ return addr + ":" + portMap[url.Scheme]
+ }
+ return addr
+}
+
+func responseIsKeepAlive(res *Response) bool {
+ // TODO: implement. for now just always shutting down the connection.
+ return false
+}
+
+// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
+// once, right before the final Read() or Close() call returns, but after
+// EOF has been seen.
+type bodyEOFSignal struct {
+ body io.ReadCloser
+ fn func()
+ isClosed bool
+}
+
+func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
+ n, err = es.body.Read(p)
+ if es.isClosed && n > 0 {
+ panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
+ }
+ if err == io.EOF && es.fn != nil {
+ es.fn()
+ es.fn = nil
+ }
+ return
+}
+
+func (es *bodyEOFSignal) Close() (err error) {
+ if es.isClosed {
+ return nil
+ }
+ es.isClosed = true
+ err = es.body.Close()
+ if err == nil && es.fn != nil {
+ es.fn()
+ es.fn = nil
+ }
+ return
+}
+
+type readFirstCloseBoth struct {
+ io.ReadCloser
+ io.Closer
+}
+
+func (r *readFirstCloseBoth) Close() error {
+ if err := r.ReadCloser.Close(); err != nil {
+ r.Closer.Close()
+ return err
+ }
+ if err := r.Closer.Close(); err != nil {
+ return err
+ }
+ return nil
+}
+
+// discardOnCloseReadCloser consumes all its input on Close.
+type discardOnCloseReadCloser struct {
+ io.ReadCloser
+}
+
+func (d *discardOnCloseReadCloser) Close() error {
+ io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed
+ return d.ReadCloser.Close()
+}
diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go
new file mode 100644
index 000000000..321da52e2
--- /dev/null
+++ b/src/pkg/net/http/transport_test.go
@@ -0,0 +1,695 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for transport.go
+
+package http_test
+
+import (
+ "bytes"
+ "compress/gzip"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+// and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.FormValue("close") == "true" {
+ w.Header().Set("Connection", "close")
+ }
+ w.Write([]byte(r.RemoteAddr))
+})
+
+// Two subsequent requests and verify their response is the same.
+// The response from the server is our own IP:port
+func TestTransportKeepAlives(t *testing.T) {
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ for _, disableKeepAlive := range []bool{false, true} {
+ tr := &Transport{DisableKeepAlives: disableKeepAlive}
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != disableKeepAlive {
+ t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ disableKeepAlive, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ for _, connectionClose := range []bool{false, true} {
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ for _, connectionClose := range []bool{false, true} {
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL)
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+ req.Close = connectionClose
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ ioutil.ReadAll(resp.Body)
+
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+ t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
+ }
+
+ tr.CloseIdleConnections()
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+}
+
+func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ resch := make(chan string)
+ gotReq := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReq <- true
+ msg := <-resch
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ }))
+ defer ts.Close()
+ maxIdleConns := 2
+ tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns}
+ c := &Client{Transport: tr}
+
+ // Start 3 outstanding requests and wait for the server to get them.
+ // Their responses will hang until we we write to resch, though.
+ donech := make(chan bool)
+ doReq := func() {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
+ donech <- true
+ }
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resch <- "res1"
+ <-donech
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
+ }
+ cacheKey := "|http|" + ts.Listener.Addr().String()
+ if keys[0] != cacheKey {
+ t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
+ }
+ if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after first response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res2"
+ <-donech
+ if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after second response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res3"
+ <-donech
+ if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after third response, still expected %d idle conns; got %d", e, g)
+ }
+}
+
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ fetch := func(n, retries int) string {
+ condFatalf := func(format string, arg ...interface{}) {
+ if retries <= 0 {
+ t.Fatalf(format, arg...)
+ }
+ t.Logf("retrying shortly after expected error: "+format, arg...)
+ time.Sleep(time.Second / time.Duration(retries))
+ }
+ for retries >= 0 {
+ retries--
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ condFatalf("error in req #%d, GET: %v", n, err)
+ continue
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ condFatalf("error in req #%d, ReadAll: %v", n, err)
+ continue
+ }
+ res.Body.Close()
+ return string(body)
+ }
+ panic("unreachable")
+ }
+
+ body1 := fetch(1, 0)
+ body2 := fetch(2, 0)
+
+ ts.CloseClientConnections() // surprise!
+
+ // This test has an expected race. Sleeping for 25 ms prevents
+ // it on most fast machines, causing the next fetch() call to
+ // succeed quickly. But if we do get errors, fetch() will retry 5
+ // times with some delays between.
+ time.Sleep(25 * time.Millisecond)
+
+ body3 := fetch(3, 5)
+
+ if body1 != body2 {
+ t.Errorf("expected body1 and body2 to be equal")
+ }
+ if body2 == body3 {
+ t.Errorf("expected body2 and body3 to be different")
+ }
+}
+
+// Test for http://golang.org/issue/2616 (appropriate issue number)
+// This fails pretty reliably with GOMAXPROCS=100 or something high.
+func TestStressSurpriseServerCloses(t *testing.T) {
+ if testing.Short() {
+ t.Logf("skipping test in short mode")
+ return
+ }
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "5")
+ w.Header().Set("Content-Type", "text/plain")
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Flush()
+ conn.Close()
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ // Do a bunch of traffic from different goroutines. Send to activityc
+ // after each request completes, regardless of whether it failed.
+ const (
+ numClients = 50
+ reqsPerClient = 250
+ )
+ activityc := make(chan bool)
+ for i := 0; i < numClients; i++ {
+ go func() {
+ for i := 0; i < reqsPerClient; i++ {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ // We expect errors since the server is
+ // hanging up on us after telling us to
+ // send more requests, so we don't
+ // actually care what the error is.
+ // But we want to close the body in cases
+ // where we won the race.
+ res.Body.Close()
+ }
+ activityc <- true
+ }
+ }()
+ }
+
+ // Make sure all the request come back, one way or another.
+ for i := 0; i < numClients*reqsPerClient; i++ {
+ select {
+ case <-activityc:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
+ }
+ }
+}
+
+// TestTransportHeadResponses verifies that we deal with Content-Lengths
+// with no bodies properly
+func TestTransportHeadResponses(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Content-Length", "123")
+ w.WriteHeader(200)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+ for i := 0; i < 2; i++ {
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Errorf("error on loop %d: %v", i, err)
+ }
+ if e, g := "123", res.Header.Get("Content-Length"); e != g {
+ t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
+ }
+ if e, g := int64(0), res.ContentLength; e != g {
+ t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
+ }
+ }
+}
+
+// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
+// on responses to HEAD requests.
+func TestTransportHeadChunkedResponse(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
+ w.Header().Set("x-client-ipport", r.RemoteAddr)
+ w.WriteHeader(200)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ res1, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("request 1 error: %v", err)
+ }
+ res2, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("request 2 error: %v", err)
+ }
+ if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
+ t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
+ }
+}
+
+var roundTripTests = []struct {
+ accept string
+ expectAccept string
+ compressed bool
+}{
+ // Requests with no accept-encoding header use transparent compression
+ {"", "gzip", false},
+ // Requests with other accept-encoding should pass through unmodified
+ {"foo", "foo", false},
+ // Requests with accept-encoding == gzip should be passed through
+ {"gzip", "gzip", true},
+}
+
+// Test that the modification made to the Request by the RoundTripper is cleaned up
+func TestRoundTripGzip(t *testing.T) {
+ const responseBody = "test response body"
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ accept := req.Header.Get("Accept-Encoding")
+ if expect := req.FormValue("expect_accept"); accept != expect {
+ t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+ req.FormValue("testnum"), accept, expect)
+ }
+ if accept == "gzip" {
+ rw.Header().Set("Content-Encoding", "gzip")
+ gz, _ := gzip.NewWriter(rw)
+ gz.Write([]byte(responseBody))
+ gz.Close()
+ } else {
+ rw.Header().Set("Content-Encoding", accept)
+ rw.Write([]byte(responseBody))
+ }
+ }))
+ defer ts.Close()
+
+ for i, test := range roundTripTests {
+ // Test basic request (no accept-encoding)
+ req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+ if test.accept != "" {
+ req.Header.Set("Accept-Encoding", test.accept)
+ }
+ res, err := DefaultTransport.RoundTrip(req)
+ var body []byte
+ if test.compressed {
+ gzip, _ := gzip.NewReader(res.Body)
+ body, err = ioutil.ReadAll(gzip)
+ res.Body.Close()
+ } else {
+ body, err = ioutil.ReadAll(res.Body)
+ }
+ if err != nil {
+ t.Errorf("%d. Error: %q", i, err)
+ continue
+ }
+ if g, e := string(body), responseBody; g != e {
+ t.Errorf("%d. body = %q; want %q", i, g, e)
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
+ }
+ }
+
+}
+
+func TestTransportGzip(t *testing.T) {
+ const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ const nRandBytes = 1024 * 1024
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
+ t.Errorf("Accept-Encoding = %q, want %q", g, e)
+ }
+ rw.Header().Set("Content-Encoding", "gzip")
+ if req.Method == "HEAD" {
+ return
+ }
+
+ var w io.Writer = rw
+ var buf bytes.Buffer
+ if req.FormValue("chunked") == "0" {
+ w = &buf
+ defer io.Copy(rw, &buf)
+ defer func() {
+ rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+ }()
+ }
+ gz, _ := gzip.NewWriter(w)
+ gz.Write([]byte(testString))
+ if req.FormValue("body") == "large" {
+ io.CopyN(gz, rand.Reader, nRandBytes)
+ }
+ gz.Close()
+ }))
+ defer ts.Close()
+
+ for _, chunked := range []string{"1", "0"} {
+ c := &Client{Transport: &Transport{}}
+
+ // First fetch something large, but only read some of it.
+ res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
+ if err != nil {
+ t.Fatalf("large get: %v", err)
+ }
+ buf := make([]byte, len(testString))
+ n, err := io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatalf("partial read of large response: size=%d, %v", n, err)
+ }
+ if e, g := testString, string(buf); e != g {
+ t.Errorf("partial read got %q, expected %q", g, e)
+ }
+ res.Body.Close()
+ // Read on the body, even though it's closed
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+ }
+
+ // Then something small.
+ res, err = c.Get(ts.URL + "/?chunked=" + chunked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := string(body), testString; g != e {
+ t.Fatalf("body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+
+ // Read on the body after it's been fully read:
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+ }
+ res.Body.Close()
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after Close; got %d, %v", n, err)
+ }
+ }
+
+ // And a HEAD request too, because they're always weird.
+ c := &Client{Transport: &Transport{}}
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("Head: %v", err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Head status=%d; want=200", res.StatusCode)
+ }
+}
+
+func TestTransportProxy(t *testing.T) {
+ ch := make(chan string, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ch <- "real server"
+ }))
+ defer ts.Close()
+ proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ch <- "proxy for " + r.URL.String()
+ }))
+ defer proxy.Close()
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
+ c.Head(ts.URL)
+ got := <-ch
+ want := "proxy for " + ts.URL + "/"
+ if got != want {
+ t.Errorf("want %q, got %q", want, got)
+ }
+}
+
+// TestTransportGzipRecursive sends a gzip quine and checks that the
+// client gets the same value back. This is more cute than anything,
+// but checks that we don't recurse forever, and checks that
+// Content-Encoding is removed.
+func TestTransportGzipRecursive(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write(rgz)
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: &Transport{}}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, rgz) {
+ t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
+ body, rgz)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+}
+
+type fooProto struct{}
+
+func (fooProto) RoundTrip(req *Request) (*Response, error) {
+ res := &Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Header: make(Header),
+ Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+ }
+ return res, nil
+}
+
+func TestTransportAltProto(t *testing.T) {
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ tr.RegisterProtocol("foo", fooProto{})
+ res, err := c.Get("foo://bar.com/path")
+ if err != nil {
+ t.Fatal(err)
+ }
+ bodyb, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := string(bodyb)
+ if e := "You wanted foo://bar.com/path"; body != e {
+ t.Errorf("got response %q, want %q", body, e)
+ }
+}
+
+// rgz is a gzip quine that uncompresses to itself.
+var rgz = []byte{
+ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
+ 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
+ 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
+ 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
+ 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
+ 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
+ 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
+ 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
+ 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
+ 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
+ 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
+ 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
+ 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
+ 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
+ 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
+ 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
+ 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
+ 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00,
+}
diff --git a/src/pkg/net/http/triv.go b/src/pkg/net/http/triv.go
new file mode 100644
index 000000000..994fc0e32
--- /dev/null
+++ b/src/pkg/net/http/triv.go
@@ -0,0 +1,149 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "expvar"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "strconv"
+)
+
+// hello world, the web server
+var helloRequests = expvar.NewInt("hello-requests")
+
+func HelloServer(w http.ResponseWriter, req *http.Request) {
+ helloRequests.Add(1)
+ io.WriteString(w, "hello, world!\n")
+}
+
+// Simple counter server. POSTing to it will set the value.
+type Counter struct {
+ n int
+}
+
+// This makes Counter satisfy the expvar.Var interface, so we can export
+// it directly.
+func (ctr *Counter) String() string { return fmt.Sprintf("%d", ctr.n) }
+
+func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ switch req.Method {
+ case "GET":
+ ctr.n++
+ case "POST":
+ buf := new(bytes.Buffer)
+ io.Copy(buf, req.Body)
+ body := buf.String()
+ if n, err := strconv.Atoi(body); err != nil {
+ fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body)
+ } else {
+ ctr.n = n
+ fmt.Fprint(w, "counter reset\n")
+ }
+ }
+ fmt.Fprintf(w, "counter = %d\n", ctr.n)
+}
+
+// simple flag server
+var booleanflag = flag.Bool("boolean", true, "another flag for testing")
+
+func FlagServer(w http.ResponseWriter, req *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprint(w, "Flags:\n")
+ flag.VisitAll(func(f *flag.Flag) {
+ if f.Value.String() != f.DefValue {
+ fmt.Fprintf(w, "%s = %s [default = %s]\n", f.Name, f.Value.String(), f.DefValue)
+ } else {
+ fmt.Fprintf(w, "%s = %s\n", f.Name, f.Value.String())
+ }
+ })
+}
+
+// simple argument server
+func ArgServer(w http.ResponseWriter, req *http.Request) {
+ for _, s := range os.Args {
+ fmt.Fprint(w, s, " ")
+ }
+}
+
+// a channel (just for the fun of it)
+type Chan chan int
+
+func ChanCreate() Chan {
+ c := make(Chan)
+ go func(c Chan) {
+ for x := 0; ; x++ {
+ c <- x
+ }
+ }(c)
+ return c
+}
+
+func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch))
+}
+
+// exec a program, redirecting output
+func DateServer(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ r, w, err := os.Pipe()
+ if err != nil {
+ fmt.Fprintf(rw, "pipe: %s\n", err)
+ return
+ }
+
+ p, err := os.StartProcess("/bin/date", []string{"date"}, &os.ProcAttr{Files: []*os.File{nil, w, w}})
+ defer r.Close()
+ w.Close()
+ if err != nil {
+ fmt.Fprintf(rw, "fork/exec: %s\n", err)
+ return
+ }
+ defer p.Release()
+ io.Copy(rw, r)
+ wait, err := p.Wait(0)
+ if err != nil {
+ fmt.Fprintf(rw, "wait: %s\n", err)
+ return
+ }
+ if !wait.Exited() || wait.ExitStatus() != 0 {
+ fmt.Fprintf(rw, "date: %v\n", wait)
+ return
+ }
+}
+
+func Logger(w http.ResponseWriter, req *http.Request) {
+ log.Print(req.URL.Raw)
+ w.WriteHeader(404)
+ w.Write([]byte("oops"))
+}
+
+var webroot = flag.String("root", "/home/rsc", "web root directory")
+
+func main() {
+ flag.Parse()
+
+ // The counter is published as a variable directly.
+ ctr := new(Counter)
+ http.Handle("/counter", ctr)
+ expvar.Publish("counter", ctr)
+
+ http.Handle("/", http.HandlerFunc(Logger))
+ http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot))))
+ http.Handle("/flags", http.HandlerFunc(FlagServer))
+ http.Handle("/args", http.HandlerFunc(ArgServer))
+ http.Handle("/go/hello", http.HandlerFunc(HelloServer))
+ http.Handle("/chan", ChanCreate())
+ http.Handle("/date", http.HandlerFunc(DateServer))
+ err := http.ListenAndServe(":12345", nil)
+ if err != nil {
+ log.Panicln("ListenAndServe:", err)
+ }
+}
diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go
index 2696b7f4c..5e7b352ed 100644
--- a/src/pkg/net/interface.go
+++ b/src/pkg/net/interface.go
@@ -8,8 +8,16 @@ package net
import (
"bytes"
+ "errors"
"fmt"
- "os"
+)
+
+var (
+ errInvalidInterface = errors.New("net: invalid interface")
+ errInvalidInterfaceIndex = errors.New("net: invalid interface index")
+ errInvalidInterfaceName = errors.New("net: invalid interface name")
+ errNoSuchInterface = errors.New("net: no such interface")
+ errNoSuchMulticastInterface = errors.New("net: no such multicast interface")
)
// A HardwareAddr represents a physical hardware address.
@@ -34,7 +42,7 @@ func (a HardwareAddr) String() string {
// 01-23-45-67-89-ab-cd-ef
// 0123.4567.89ab
// 0123.4567.89ab.cdef
-func ParseMAC(s string) (hw HardwareAddr, err os.Error) {
+func ParseMAC(s string) (hw HardwareAddr, err error) {
if len(s) < 14 {
goto error
}
@@ -80,7 +88,7 @@ func ParseMAC(s string) (hw HardwareAddr, err os.Error) {
return hw, nil
error:
- return nil, os.NewError("invalid MAC address: " + s)
+ return nil, errors.New("invalid MAC address: " + s)
}
// Interface represents a mapping between network interface name
@@ -129,37 +137,37 @@ func (f Flags) String() string {
}
// Addrs returns interface addresses for a specific interface.
-func (ifi *Interface) Addrs() ([]Addr, os.Error) {
+func (ifi *Interface) Addrs() ([]Addr, error) {
if ifi == nil {
- return nil, os.NewError("net: invalid interface")
+ return nil, errInvalidInterface
}
return interfaceAddrTable(ifi.Index)
}
// MulticastAddrs returns multicast, joined group addresses for
// a specific interface.
-func (ifi *Interface) MulticastAddrs() ([]Addr, os.Error) {
+func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
if ifi == nil {
- return nil, os.NewError("net: invalid interface")
+ return nil, errInvalidInterface
}
return interfaceMulticastAddrTable(ifi.Index)
}
// Interfaces returns a list of the systems's network interfaces.
-func Interfaces() ([]Interface, os.Error) {
+func Interfaces() ([]Interface, error) {
return interfaceTable(0)
}
// InterfaceAddrs returns a list of the system's network interface
// addresses.
-func InterfaceAddrs() ([]Addr, os.Error) {
+func InterfaceAddrs() ([]Addr, error) {
return interfaceAddrTable(0)
}
// InterfaceByIndex returns the interface specified by index.
-func InterfaceByIndex(index int) (*Interface, os.Error) {
+func InterfaceByIndex(index int) (*Interface, error) {
if index <= 0 {
- return nil, os.NewError("net: invalid interface index")
+ return nil, errInvalidInterfaceIndex
}
ift, err := interfaceTable(index)
if err != nil {
@@ -168,13 +176,13 @@ func InterfaceByIndex(index int) (*Interface, os.Error) {
for _, ifi := range ift {
return &ifi, nil
}
- return nil, os.NewError("net: no such interface")
+ return nil, errNoSuchInterface
}
// InterfaceByName returns the interface specified by name.
-func InterfaceByName(name string) (*Interface, os.Error) {
+func InterfaceByName(name string) (*Interface, error) {
if name == "" {
- return nil, os.NewError("net: invalid interface name")
+ return nil, errInvalidInterfaceName
}
ift, err := interfaceTable(0)
if err != nil {
@@ -185,5 +193,5 @@ func InterfaceByName(name string) (*Interface, os.Error) {
return &ifi, nil
}
}
- return nil, os.NewError("net: no such interface")
+ return nil, errNoSuchInterface
}
diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go
index 9171827d2..907f80a80 100644
--- a/src/pkg/net/interface_bsd.go
+++ b/src/pkg/net/interface_bsd.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd openbsd
+// +build darwin freebsd netbsd openbsd
// Network interface identification for BSD variants
@@ -17,22 +17,17 @@ import (
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otheriwse it returns a mapping of a specific
// interface.
-func interfaceTable(ifindex int) ([]Interface, os.Error) {
- var (
- tab []byte
- e int
- msgs []syscall.RoutingMessage
- ift []Interface
- )
-
- tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex)
- if e != 0 {
- return nil, os.NewSyscallError("route rib", e)
+func interfaceTable(ifindex int) ([]Interface, error) {
+ var ift []Interface
+
+ tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex)
+ if err != nil {
+ return nil, os.NewSyscallError("route rib", err)
}
- msgs, e = syscall.ParseRoutingMessage(tab)
- if e != 0 {
- return nil, os.NewSyscallError("route message", e)
+ msgs, err := syscall.ParseRoutingMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("route message", err)
}
for _, m := range msgs {
@@ -51,12 +46,12 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) {
return ift, nil
}
-func newLink(m *syscall.InterfaceMessage) ([]Interface, os.Error) {
+func newLink(m *syscall.InterfaceMessage) ([]Interface, error) {
var ift []Interface
- sas, e := syscall.ParseRoutingSockaddr(m)
- if e != 0 {
- return nil, os.NewSyscallError("route sockaddr", e)
+ sas, err := syscall.ParseRoutingSockaddr(m)
+ if err != nil {
+ return nil, os.NewSyscallError("route sockaddr", err)
}
for _, s := range sas {
@@ -107,22 +102,17 @@ func linkFlags(rawFlags int32) Flags {
// If the ifindex is zero, interfaceAddrTable returns addresses
// for all network interfaces. Otherwise it returns addresses
// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
- var (
- tab []byte
- e int
- msgs []syscall.RoutingMessage
- ifat []Addr
- )
-
- tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex)
- if e != 0 {
- return nil, os.NewSyscallError("route rib", e)
+func interfaceAddrTable(ifindex int) ([]Addr, error) {
+ var ifat []Addr
+
+ tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex)
+ if err != nil {
+ return nil, os.NewSyscallError("route rib", err)
}
- msgs, e = syscall.ParseRoutingMessage(tab)
- if e != 0 {
- return nil, os.NewSyscallError("route message", e)
+ msgs, err := syscall.ParseRoutingMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("route message", err)
}
for _, m := range msgs {
@@ -133,7 +123,7 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
if err != nil {
return nil, err
}
- ifat = append(ifat, ifa...)
+ ifat = append(ifat, ifa)
}
}
}
@@ -141,33 +131,41 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
return ifat, nil
}
-func newAddr(m *syscall.InterfaceAddrMessage) ([]Addr, os.Error) {
- var ifat []Addr
+func newAddr(m *syscall.InterfaceAddrMessage) (Addr, error) {
+ ifa := &IPNet{}
- sas, e := syscall.ParseRoutingSockaddr(m)
- if e != 0 {
- return nil, os.NewSyscallError("route sockaddr", e)
+ sas, err := syscall.ParseRoutingSockaddr(m)
+ if err != nil {
+ return nil, os.NewSyscallError("route sockaddr", err)
}
- for _, s := range sas {
-
+ for i, s := range sas {
switch v := s.(type) {
case *syscall.SockaddrInet4:
- ifa := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])}
- ifat = append(ifat, ifa.toAddr())
+ switch i {
+ case 0:
+ ifa.Mask = IPv4Mask(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])
+ case 1:
+ ifa.IP = IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])
+ }
case *syscall.SockaddrInet6:
- ifa := &IPAddr{IP: make(IP, IPv6len)}
- copy(ifa.IP, v.Addr[:])
- // NOTE: KAME based IPv6 protcol stack usually embeds
- // the interface index in the interface-local or link-
- // local address as the kernel-internal form.
- if ifa.IP.IsLinkLocalUnicast() {
- // remove embedded scope zone ID
- ifa.IP[2], ifa.IP[3] = 0, 0
+ switch i {
+ case 0:
+ ifa.Mask = make(IPMask, IPv6len)
+ copy(ifa.Mask, v.Addr[:])
+ case 1:
+ ifa.IP = make(IP, IPv6len)
+ copy(ifa.IP, v.Addr[:])
+ // NOTE: KAME based IPv6 protcol stack usually embeds
+ // the interface index in the interface-local or link-
+ // local address as the kernel-internal form.
+ if ifa.IP.IsLinkLocalUnicast() {
+ // remove embedded scope zone ID
+ ifa.IP[2], ifa.IP[3] = 0, 0
+ }
}
- ifat = append(ifat, ifa.toAddr())
}
}
- return ifat, nil
+ return ifa, nil
}
diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go
index a7b68ad7f..2da447adc 100644
--- a/src/pkg/net/interface_darwin.go
+++ b/src/pkg/net/interface_darwin.go
@@ -14,21 +14,21 @@ import (
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
var (
tab []byte
- e int
+ e error
msgs []syscall.RoutingMessage
ifmat []Addr
)
tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route rib", e)
}
msgs, e = syscall.ParseRoutingMessage(tab)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route message", e)
}
@@ -48,11 +48,11 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
return ifmat, nil
}
-func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) {
+func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
var ifmat []Addr
sas, e := syscall.ParseRoutingSockaddr(m)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route sockaddr", e)
}
diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go
index 20f506b08..a12877e25 100644
--- a/src/pkg/net/interface_freebsd.go
+++ b/src/pkg/net/interface_freebsd.go
@@ -14,21 +14,21 @@ import (
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
var (
tab []byte
- e int
+ e error
msgs []syscall.RoutingMessage
ifmat []Addr
)
tab, e = syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route rib", e)
}
msgs, e = syscall.ParseRoutingMessage(tab)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route message", e)
}
@@ -48,11 +48,11 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
return ifmat, nil
}
-func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) {
+func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
var ifmat []Addr
sas, e := syscall.ParseRoutingSockaddr(m)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("route sockaddr", e)
}
diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go
index 3d2a0bb9f..c0887c57e 100644
--- a/src/pkg/net/interface_linux.go
+++ b/src/pkg/net/interface_linux.go
@@ -16,22 +16,17 @@ import (
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otheriwse it returns a mapping of a specific
// interface.
-func interfaceTable(ifindex int) ([]Interface, os.Error) {
- var (
- ift []Interface
- tab []byte
- msgs []syscall.NetlinkMessage
- e int
- )
+func interfaceTable(ifindex int) ([]Interface, error) {
+ var ift []Interface
- tab, e = syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC)
- if e != 0 {
- return nil, os.NewSyscallError("netlink rib", e)
+ tab, err := syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, os.NewSyscallError("netlink rib", err)
}
- msgs, e = syscall.ParseNetlinkMessage(tab)
- if e != 0 {
- return nil, os.NewSyscallError("netlink message", e)
+ msgs, err := syscall.ParseNetlinkMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("netlink message", err)
}
for _, m := range msgs {
@@ -41,11 +36,11 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) {
case syscall.RTM_NEWLINK:
ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0]))
if ifindex == 0 || ifindex == int(ifim.Index) {
- attrs, e := syscall.ParseNetlinkRouteAttr(&m)
- if e != 0 {
- return nil, os.NewSyscallError("netlink routeattr", e)
+ attrs, err := syscall.ParseNetlinkRouteAttr(&m)
+ if err != nil {
+ return nil, os.NewSyscallError("netlink routeattr", err)
}
- ifi := newLink(attrs, ifim)
+ ifi := newLink(ifim, attrs)
ift = append(ift, ifi)
}
}
@@ -55,7 +50,7 @@ done:
return ift, nil
}
-func newLink(attrs []syscall.NetlinkRouteAttr, ifim *syscall.IfInfomsg) Interface {
+func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interface {
ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)}
for _, a := range attrs {
switch a.Attr.Type {
@@ -101,47 +96,26 @@ func linkFlags(rawFlags uint32) Flags {
// If the ifindex is zero, interfaceAddrTable returns addresses
// for all network interfaces. Otherwise it returns addresses
// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
- var (
- tab []byte
- e int
- err os.Error
- ifat4 []Addr
- ifat6 []Addr
- msgs4 []syscall.NetlinkMessage
- msgs6 []syscall.NetlinkMessage
- )
-
- tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET)
- if e != 0 {
- return nil, os.NewSyscallError("netlink rib", e)
- }
- msgs4, e = syscall.ParseNetlinkMessage(tab)
- if e != 0 {
- return nil, os.NewSyscallError("netlink message", e)
- }
- ifat4, err = addrTable(msgs4, ifindex)
+func interfaceAddrTable(ifindex int) ([]Addr, error) {
+ tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC)
if err != nil {
- return nil, err
+ return nil, os.NewSyscallError("netlink rib", err)
}
- tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET6)
- if e != 0 {
- return nil, os.NewSyscallError("netlink rib", e)
- }
- msgs6, e = syscall.ParseNetlinkMessage(tab)
- if e != 0 {
- return nil, os.NewSyscallError("netlink message", e)
+ msgs, err := syscall.ParseNetlinkMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("netlink message", err)
}
- ifat6, err = addrTable(msgs6, ifindex)
+
+ ifat, err := addrTable(msgs, ifindex)
if err != nil {
return nil, err
}
- return append(ifat4, ifat6...), nil
+ return ifat, nil
}
-func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, os.Error) {
+func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, error) {
var ifat []Addr
for _, m := range msgs {
@@ -151,11 +125,11 @@ func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, os.Error) {
case syscall.RTM_NEWADDR:
ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0]))
if ifindex == 0 || ifindex == int(ifam.Index) {
- attrs, e := syscall.ParseNetlinkRouteAttr(&m)
- if e != 0 {
- return nil, os.NewSyscallError("netlink routeattr", e)
+ attrs, err := syscall.ParseNetlinkRouteAttr(&m)
+ if err != nil {
+ return nil, os.NewSyscallError("netlink routeattr", err)
}
- ifat = append(ifat, newAddr(attrs, int(ifam.Family))...)
+ ifat = append(ifat, newAddr(attrs, int(ifam.Family), int(ifam.Prefixlen)))
}
}
}
@@ -164,34 +138,32 @@ done:
return ifat, nil
}
-func newAddr(attrs []syscall.NetlinkRouteAttr, family int) []Addr {
- var ifat []Addr
-
+func newAddr(attrs []syscall.NetlinkRouteAttr, family, pfxlen int) Addr {
+ ifa := &IPNet{}
for _, a := range attrs {
switch a.Attr.Type {
case syscall.IFA_ADDRESS:
switch family {
case syscall.AF_INET:
- ifa := &IPAddr{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3])}
- ifat = append(ifat, ifa.toAddr())
+ ifa.IP = IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3])
+ ifa.Mask = CIDRMask(pfxlen, 8*IPv4len)
case syscall.AF_INET6:
- ifa := &IPAddr{IP: make(IP, IPv6len)}
+ ifa.IP = make(IP, IPv6len)
copy(ifa.IP, a.Value[:])
- ifat = append(ifat, ifa.toAddr())
+ ifa.Mask = CIDRMask(pfxlen, 8*IPv6len)
}
}
}
-
- return ifat
+ return ifa
}
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
var (
+ err error
ifi *Interface
- err os.Error
)
if ifindex > 0 {
diff --git a/src/pkg/net/interface_netbsd.go b/src/pkg/net/interface_netbsd.go
new file mode 100644
index 000000000..4150e9ad5
--- /dev/null
+++ b/src/pkg/net/interface_netbsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Network interface identification for NetBSD
+
+package net
+
+// If the ifindex is zero, interfaceMulticastAddrTable returns
+// addresses for all network interfaces. Otherwise it returns
+// addresses for a specific interface.
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
+ return nil, nil
+}
diff --git a/src/pkg/net/interface_openbsd.go b/src/pkg/net/interface_openbsd.go
index f18149393..d8adb4676 100644
--- a/src/pkg/net/interface_openbsd.go
+++ b/src/pkg/net/interface_openbsd.go
@@ -6,11 +6,9 @@
package net
-import "os"
-
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
return nil, nil
}
diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go
index 282b38b5e..4876b3af3 100644
--- a/src/pkg/net/interface_stub.go
+++ b/src/pkg/net/interface_stub.go
@@ -8,25 +8,23 @@
package net
-import "os"
-
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otheriwse it returns a mapping of a specific
// interface.
-func interfaceTable(ifindex int) ([]Interface, os.Error) {
+func interfaceTable(ifindex int) ([]Interface, error) {
return nil, nil
}
// If the ifindex is zero, interfaceAddrTable returns addresses
// for all network interfaces. Otherwise it returns addresses
// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceAddrTable(ifindex int) ([]Addr, error) {
return nil, nil
}
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
return nil, nil
}
diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go
index c918f247f..4ce01dc90 100644
--- a/src/pkg/net/interface_test.go
+++ b/src/pkg/net/interface_test.go
@@ -6,7 +6,6 @@ package net
import (
"bytes"
- "os"
"reflect"
"strings"
"testing"
@@ -25,7 +24,7 @@ func sameInterface(i, j *Interface) bool {
func TestInterfaces(t *testing.T) {
ift, err := Interfaces()
if err != nil {
- t.Fatalf("Interfaces() failed: %v", err)
+ t.Fatalf("Interfaces failed: %v", err)
}
t.Logf("table: len/cap = %v/%v\n", len(ift), cap(ift))
@@ -44,34 +43,57 @@ func TestInterfaces(t *testing.T) {
if !sameInterface(ifxn, &ifi) {
t.Fatalf("InterfaceByName(%#q) = %v, want %v", ifi.Name, *ifxn, ifi)
}
- ifat, err := ifi.Addrs()
- if err != nil {
- t.Fatalf("Interface.Addrs() failed: %v", err)
- }
- ifmat, err := ifi.MulticastAddrs()
- if err != nil {
- t.Fatalf("Interface.MulticastAddrs() failed: %v", err)
- }
t.Logf("%q: flags %q, ifindex %v, mtu %v\n", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU)
- for _, ifa := range ifat {
- t.Logf("\tinterface address %q\n", ifa.String())
- }
- for _, ifma := range ifmat {
- t.Logf("\tjoined group address %q\n", ifma.String())
- }
t.Logf("\thardware address %q", ifi.HardwareAddr.String())
+ testInterfaceAddrs(t, &ifi)
+ testInterfaceMulticastAddrs(t, &ifi)
}
}
func TestInterfaceAddrs(t *testing.T) {
ifat, err := InterfaceAddrs()
if err != nil {
- t.Fatalf("InterfaceAddrs() failed: %v", err)
+ t.Fatalf("InterfaceAddrs failed: %v", err)
}
t.Logf("table: len/cap = %v/%v\n", len(ifat), cap(ifat))
+ testAddrs(t, ifat)
+}
+
+func testInterfaceAddrs(t *testing.T, ifi *Interface) {
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ t.Fatalf("Interface.Addrs failed: %v", err)
+ }
+ testAddrs(t, ifat)
+}
+
+func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) {
+ ifmat, err := ifi.MulticastAddrs()
+ if err != nil {
+ t.Fatalf("Interface.MulticastAddrs failed: %v", err)
+ }
+ testMulticastAddrs(t, ifmat)
+}
+func testAddrs(t *testing.T, ifat []Addr) {
for _, ifa := range ifat {
- t.Logf("interface address %q\n", ifa.String())
+ switch ifa.(type) {
+ case *IPAddr, *IPNet:
+ t.Logf("\tinterface address %q\n", ifa.String())
+ default:
+ t.Errorf("\tunexpected type: %T", ifa)
+ }
+ }
+}
+
+func testMulticastAddrs(t *testing.T, ifmat []Addr) {
+ for _, ifma := range ifmat {
+ switch ifma.(type) {
+ case *IPAddr:
+ t.Logf("\tjoined group address %q\n", ifma.String())
+ default:
+ t.Errorf("\tunexpected type: %T", ifma)
+ }
}
}
@@ -101,11 +123,11 @@ var mactests = []struct {
{"0123.4567.89AB.CDEF", HardwareAddr{1, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, ""},
}
-func match(err os.Error, s string) bool {
+func match(err error, s string) bool {
if s == "" {
return err == nil
}
- return err != nil && strings.Contains(err.String(), s)
+ return err != nil && strings.Contains(err.Error(), s)
}
func TestParseMAC(t *testing.T) {
diff --git a/src/pkg/net/interface_windows.go b/src/pkg/net/interface_windows.go
index 7f5169c87..add3dd3b9 100644
--- a/src/pkg/net/interface_windows.go
+++ b/src/pkg/net/interface_windows.go
@@ -21,7 +21,7 @@ func bytePtrToString(p *uint8) string {
return string(a[:i])
}
-func getAdapterList() (*syscall.IpAdapterInfo, os.Error) {
+func getAdapterList() (*syscall.IpAdapterInfo, error) {
b := make([]byte, 1000)
l := uint32(len(b))
a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0]))
@@ -31,15 +31,15 @@ func getAdapterList() (*syscall.IpAdapterInfo, os.Error) {
a = (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0]))
e = syscall.GetAdaptersInfo(a, &l)
}
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("GetAdaptersInfo", e)
}
return a, nil
}
-func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) {
+func getInterfaceList() ([]syscall.InterfaceInfo, error) {
s, e := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("Socket", e)
}
defer syscall.Closesocket(s)
@@ -48,7 +48,7 @@ func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) {
ret := uint32(0)
size := uint32(unsafe.Sizeof(ii))
e = syscall.WSAIoctl(s, syscall.SIO_GET_INTERFACE_LIST, nil, 0, (*byte)(unsafe.Pointer(&ii[0])), size, &ret, nil, 0)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("WSAIoctl", e)
}
c := ret / uint32(unsafe.Sizeof(ii[0]))
@@ -58,7 +58,7 @@ func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) {
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otheriwse it returns a mapping of a specific
// interface.
-func interfaceTable(ifindex int) ([]Interface, os.Error) {
+func interfaceTable(ifindex int) ([]Interface, error) {
ai, e := getAdapterList()
if e != nil {
return nil, e
@@ -77,7 +77,7 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) {
row := syscall.MibIfRow{Index: index}
e := syscall.GetIfEntry(&row)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("GetIfEntry", e)
}
@@ -129,7 +129,7 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) {
// If the ifindex is zero, interfaceAddrTable returns addresses
// for all network interfaces. Otherwise it returns addresses
// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceAddrTable(ifindex int) ([]Addr, error) {
ai, e := getAdapterList()
if e != nil {
return nil, e
@@ -153,6 +153,6 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) {
// If the ifindex is zero, interfaceMulticastAddrTable returns
// addresses for all network interfaces. Otherwise it returns
// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) {
+func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
return nil, nil
}
diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go
index 61dc3be90..979d7acd5 100644
--- a/src/pkg/net/ip.go
+++ b/src/pkg/net/ip.go
@@ -12,8 +12,6 @@
package net
-import "os"
-
// IP address lengths (bytes).
const (
IPv4len = 4
@@ -452,6 +450,9 @@ func (n *IPNet) String() string {
return nn.String() + "/" + itod(uint(l))
}
+// Network returns the address's network name, "ip+net".
+func (n *IPNet) Network() string { return "ip+net" }
+
// Parse IPv4 address (d.d.d.d).
func parseIPv4(s string) IP {
var p [IPv4len]byte
@@ -594,7 +595,7 @@ type ParseError struct {
Text string
}
-func (e *ParseError) String() string {
+func (e *ParseError) Error() string {
return "invalid " + e.Type + ": " + e.Text
}
@@ -627,7 +628,7 @@ func ParseIP(s string) IP {
// It returns the IP address and the network implied by the IP
// and mask. For example, ParseCIDR("192.168.100.1/16") returns
// the IP address 192.168.100.1 and the network 192.168.0.0/16.
-func ParseCIDR(s string) (IP, *IPNet, os.Error) {
+func ParseCIDR(s string) (IP, *IPNet, error) {
i := byteIndex(s, '/')
if i < 0 {
return nil, nil, &ParseError{"CIDR address", s}
diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go
index 07e627aef..df647ef73 100644
--- a/src/pkg/net/ip_test.go
+++ b/src/pkg/net/ip_test.go
@@ -7,9 +7,8 @@ package net
import (
"bytes"
"reflect"
- "testing"
- "os"
"runtime"
+ "testing"
)
func isEqual(a, b []byte) bool {
@@ -113,7 +112,7 @@ var parsecidrtests = []struct {
in string
ip IP
net *IPNet
- err os.Error
+ err error
}{
{"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255)}, nil},
{"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil},
diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go
index 6894ce656..f9401c110 100644
--- a/src/pkg/net/ipraw_test.go
+++ b/src/pkg/net/ipraw_test.go
@@ -2,119 +2,191 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// TODO(cw): ListenPacket test, Read() test, ipv6 test &
-// Dial()/Listen() level tests
-
package net
import (
"bytes"
- "flag"
"os"
"testing"
+ "time"
)
-const ICMP_ECHO_REQUEST = 8
-const ICMP_ECHO_REPLY = 0
-
-// returns a suitable 'ping request' packet, with id & seq and a
-// payload length of pktlen
-func makePingRequest(id, seq, pktlen int, filler []byte) []byte {
- p := make([]byte, pktlen)
- copy(p[8:], bytes.Repeat(filler, (pktlen-8)/len(filler)+1))
-
- p[0] = ICMP_ECHO_REQUEST // type
- p[1] = 0 // code
- p[2] = 0 // cksum
- p[3] = 0 // cksum
- p[4] = uint8(id >> 8) // id
- p[5] = uint8(id & 0xff) // id
- p[6] = uint8(seq >> 8) // sequence
- p[7] = uint8(seq & 0xff) // sequence
-
- // calculate icmp checksum
- cklen := len(p)
- s := uint32(0)
- for i := 0; i < (cklen - 1); i += 2 {
- s += uint32(p[i+1])<<8 | uint32(p[i])
- }
- if cklen&1 == 1 {
- s += uint32(p[cklen-1])
- }
- s = (s >> 16) + (s & 0xffff)
- s = s + (s >> 16)
-
- // place checksum back in header; using ^= avoids the
- // assumption the checksum bytes are zero
- p[2] ^= uint8(^s & 0xff)
- p[3] ^= uint8(^s >> 8)
-
- return p
-}
-
-func parsePingReply(p []byte) (id, seq int) {
- id = int(p[4])<<8 | int(p[5])
- seq = int(p[6])<<8 | int(p[7])
- return
+var icmpTests = []struct {
+ net string
+ laddr string
+ raddr string
+ ipv6 bool
+}{
+ {"ip4:icmp", "", "127.0.0.1", false},
+ {"ip6:icmp", "", "::1", true},
}
-var srchost = flag.String("srchost", "", "Source of the ICMP ECHO request")
-// 127.0.0.1 because this is an IPv4-specific test.
-var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO request")
-
-// test (raw) IP socket using ICMP
func TestICMP(t *testing.T) {
if os.Getuid() != 0 {
t.Logf("test disabled; must be root")
return
}
- var (
- laddr *IPAddr
- err os.Error
- )
- if *srchost != "" {
- laddr, err = ResolveIPAddr("ip4", *srchost)
- if err != nil {
- t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *srchost, laddr, err)
+ seqnum := 61455
+ for _, tt := range icmpTests {
+ if tt.ipv6 && !supportsIPv6 {
+ continue
}
+ id := os.Getpid() & 0xffff
+ seqnum++
+ echo := newICMPEchoRequest(tt.ipv6, id, seqnum, 128, []byte("Go Go Gadget Ping!!!"))
+ exchangeICMPEcho(t, tt.net, tt.laddr, tt.raddr, tt.ipv6, echo)
}
+}
- raddr, err := ResolveIPAddr("ip4", *dsthost)
+func exchangeICMPEcho(t *testing.T, net, laddr, raddr string, ipv6 bool, echo []byte) {
+ c, err := ListenPacket(net, laddr)
if err != nil {
- t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *dsthost, raddr, err)
+ t.Errorf("ListenPacket(%#q, %#q) failed: %v", net, laddr, err)
+ return
}
+ c.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ defer c.Close()
- c, err := ListenIP("ip4:icmp", laddr)
+ ra, err := ResolveIPAddr(net, raddr)
if err != nil {
- t.Fatalf(`net.ListenIP("ip4:icmp", %v) = %v, %v`, *srchost, c, err)
+ t.Errorf("ResolveIPAddr(%#q, %#q) failed: %v", net, raddr, err)
+ return
}
- sendid := os.Getpid() & 0xffff
- const sendseq = 61455
- const pingpktlen = 128
- sendpkt := makePingRequest(sendid, sendseq, pingpktlen, []byte("Go Go Gadget Ping!!!"))
+ waitForReady := make(chan bool)
+ go icmpEchoTransponder(t, net, raddr, ipv6, waitForReady)
+ <-waitForReady
- n, err := c.WriteToIP(sendpkt, raddr)
- if err != nil || n != pingpktlen {
- t.Fatalf(`net.WriteToIP(..., %v) = %v, %v`, raddr, n, err)
+ _, err = c.WriteTo(echo, ra)
+ if err != nil {
+ t.Errorf("WriteTo failed: %v", err)
+ return
}
- c.SetTimeout(100e6)
- resp := make([]byte, 1024)
+ reply := make([]byte, 256)
for {
- n, from, err := c.ReadFrom(resp)
+ _, _, err := c.ReadFrom(reply)
if err != nil {
- t.Fatalf(`ReadFrom(...) = %v, %v, %v`, n, from, err)
+ t.Errorf("ReadFrom failed: %v", err)
+ return
}
- if resp[0] != ICMP_ECHO_REPLY {
+ if !ipv6 && reply[0] != ICMP4_ECHO_REPLY {
continue
}
- rcvid, rcvseq := parsePingReply(resp)
- if rcvid != sendid || rcvseq != sendseq {
- t.Fatalf(`Ping reply saw id,seq=0x%x,0x%x (expected 0x%x, 0x%x)`, rcvid, rcvseq, sendid, sendseq)
+ if ipv6 && reply[0] != ICMP6_ECHO_REPLY {
+ continue
+ }
+ xid, xseqnum := parseICMPEchoReply(echo)
+ rid, rseqnum := parseICMPEchoReply(reply)
+ if rid != xid || rseqnum != xseqnum {
+ t.Errorf("ID = %v, Seqnum = %v, want ID = %v, Seqnum = %v", rid, rseqnum, xid, xseqnum)
+ return
}
+ break
+ }
+}
+
+func icmpEchoTransponder(t *testing.T, net, raddr string, ipv6 bool, waitForReady chan bool) {
+ c, err := Dial(net, raddr)
+ if err != nil {
+ waitForReady <- true
+ t.Errorf("Dial(%#q, %#q) failed: %v", net, raddr, err)
return
}
- t.Fatalf("saw no ping return")
+ c.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ defer c.Close()
+ waitForReady <- true
+
+ echo := make([]byte, 256)
+ var nr int
+ for {
+ nr, err = c.Read(echo)
+ if err != nil {
+ t.Errorf("Read failed: %v", err)
+ return
+ }
+ if !ipv6 && echo[0] != ICMP4_ECHO_REQUEST {
+ continue
+ }
+ if ipv6 && echo[0] != ICMP6_ECHO_REQUEST {
+ continue
+ }
+ break
+ }
+
+ if !ipv6 {
+ echo[0] = ICMP4_ECHO_REPLY
+ } else {
+ echo[0] = ICMP6_ECHO_REPLY
+ }
+
+ _, err = c.Write(echo[:nr])
+ if err != nil {
+ t.Errorf("Write failed: %v", err)
+ return
+ }
+}
+
+const (
+ ICMP4_ECHO_REQUEST = 8
+ ICMP4_ECHO_REPLY = 0
+ ICMP6_ECHO_REQUEST = 128
+ ICMP6_ECHO_REPLY = 129
+)
+
+func newICMPEchoRequest(ipv6 bool, id, seqnum, msglen int, filler []byte) []byte {
+ if !ipv6 {
+ return newICMPv4EchoRequest(id, seqnum, msglen, filler)
+ }
+ return newICMPv6EchoRequest(id, seqnum, msglen, filler)
+}
+
+func newICMPv4EchoRequest(id, seqnum, msglen int, filler []byte) []byte {
+ b := newICMPInfoMessage(id, seqnum, msglen, filler)
+ b[0] = ICMP4_ECHO_REQUEST
+
+ // calculate ICMP checksum
+ cklen := len(b)
+ s := uint32(0)
+ for i := 0; i < cklen-1; i += 2 {
+ s += uint32(b[i+1])<<8 | uint32(b[i])
+ }
+ if cklen&1 == 1 {
+ s += uint32(b[cklen-1])
+ }
+ s = (s >> 16) + (s & 0xffff)
+ s = s + (s >> 16)
+ // place checksum back in header; using ^= avoids the
+ // assumption the checksum bytes are zero
+ b[2] ^= uint8(^s & 0xff)
+ b[3] ^= uint8(^s >> 8)
+
+ return b
+}
+
+func newICMPv6EchoRequest(id, seqnum, msglen int, filler []byte) []byte {
+ b := newICMPInfoMessage(id, seqnum, msglen, filler)
+ b[0] = ICMP6_ECHO_REQUEST
+ return b
+}
+
+func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte {
+ b := make([]byte, msglen)
+ copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1))
+ b[0] = 0 // type
+ b[1] = 0 // code
+ b[2] = 0 // checksum
+ b[3] = 0 // checksum
+ b[4] = uint8(id >> 8) // identifier
+ b[5] = uint8(id & 0xff) // identifier
+ b[6] = uint8(seqnum >> 8) // sequence number
+ b[7] = uint8(seqnum & 0xff) // sequence number
+ return b
+}
+
+func parseICMPEchoReply(b []byte) (id, seqnum int) {
+ id = int(b[4])<<8 | int(b[5])
+ seqnum = int(b[6])<<8 | int(b[7])
+ return
}
diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go
index 662b9f57b..b23213ee1 100644
--- a/src/pkg/net/iprawsock.go
+++ b/src/pkg/net/iprawsock.go
@@ -6,10 +6,6 @@
package net
-import (
- "os"
-)
-
// IPAddr represents the address of a IP end point.
type IPAddr struct {
IP IP
@@ -29,7 +25,7 @@ func (a *IPAddr) String() string {
// names to numeric addresses on the network net, which must be
// "ip", "ip4" or "ip6". A literal IPv6 host address must be
// enclosed in square brackets, as in "[::]".
-func ResolveIPAddr(net, addr string) (*IPAddr, os.Error) {
+func ResolveIPAddr(net, addr string) (*IPAddr, error) {
ip, err := hostToIP(net, addr)
if err != nil {
return nil, err
@@ -38,7 +34,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, os.Error) {
}
// Convert "host" into IP address.
-func hostToIP(net, host string) (ip IP, err os.Error) {
+func hostToIP(net, host string) (ip IP, err error) {
var addr IP
// Try as an IP address.
addr = ParseIP(host)
diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go
index 808e17974..859153c2a 100644
--- a/src/pkg/net/iprawsock_plan9.go
+++ b/src/pkg/net/iprawsock_plan9.go
@@ -8,26 +8,42 @@ package net
import (
"os"
+ "time"
)
// IPConn is the implementation of the Conn and PacketConn
// interfaces for IP network connections.
type IPConn bool
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *IPConn) SetDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *IPConn) SetReadDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *IPConn) SetWriteDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
// Implementation of the Conn interface - see Conn for documentation.
-// Read implements the net.Conn Read method.
-func (c *IPConn) Read(b []byte) (n int, err os.Error) {
+// Read implements the Conn Read method.
+func (c *IPConn) Read(b []byte) (int, error) {
return 0, os.EPLAN9
}
-// Write implements the net.Conn Write method.
-func (c *IPConn) Write(b []byte) (n int, err os.Error) {
+// Write implements the Conn Write method.
+func (c *IPConn) Write(b []byte) (int, error) {
return 0, os.EPLAN9
}
// Close closes the IP connection.
-func (c *IPConn) Close() os.Error {
+func (c *IPConn) Close() error {
return os.EPLAN9
}
@@ -41,52 +57,42 @@ func (c *IPConn) RemoteAddr() Addr {
return nil
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *IPConn) SetTimeout(nsec int64) os.Error {
- return os.EPLAN9
-}
-
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *IPConn) SetReadTimeout(nsec int64) os.Error {
- return os.EPLAN9
-}
+// IP-specific methods.
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *IPConn) SetWriteTimeout(nsec int64) os.Error {
- return os.EPLAN9
+// ReadFromIP reads a IP packet from c, copying the payload into b.
+// It returns the number of bytes copied into b and the return address
+// that was on the packet.
+//
+// ReadFromIP can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetReadDeadline.
+func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
+ return 0, nil, os.EPLAN9
}
-// IP-specific methods.
-
-// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
- err = os.EPLAN9
- return
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
+ return 0, nil, os.EPLAN9
}
// WriteToIP writes a IP packet to addr via c, copying the payload from b.
//
// WriteToIP can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetWriteTimeout.
+// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
-func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) {
+func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
return 0, os.EPLAN9
}
-// WriteTo implements the net.PacketConn WriteTo method.
-func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+// WriteTo implements the PacketConn WriteTo method.
+func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
return 0, os.EPLAN9
}
-func splitNetProto(netProto string) (net string, proto int, err os.Error) {
- err = os.EPLAN9
- return
-}
-
-// DialIP connects to the remote address raddr on the network net,
-// which must be "ip", "ip4", or "ip6".
-func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) {
+// DialIP connects to the remote address raddr on the network protocol netProto,
+// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
+func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
return nil, os.EPLAN9
}
@@ -94,6 +100,6 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) {
// local address laddr. The returned connection c's ReadFrom
// and WriteTo methods can be used to receive and send IP
// packets with per-packet addressing.
-func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) {
+func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
return nil, os.EPLAN9
}
diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go
index 35aceb223..c34ffeb12 100644
--- a/src/pkg/net/iprawsock_posix.go
+++ b/src/pkg/net/iprawsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
// (Raw) IP sockets
@@ -10,12 +10,10 @@ package net
import (
"os"
- "sync"
"syscall"
+ "time"
)
-var onceReadProtocols sync.Once
-
func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
@@ -36,7 +34,7 @@ func (a *IPAddr) family() int {
return syscall.AF_INET6
}
-func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) {
+func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, 0)
}
@@ -59,14 +57,14 @@ func (c *IPConn) ok() bool { return c != nil && c.fd != nil }
// Implementation of the Conn interface - see Conn for documentation.
-// Read implements the net.Conn Read method.
-func (c *IPConn) Read(b []byte) (n int, err os.Error) {
- n, _, err = c.ReadFrom(b)
- return
+// Read implements the Conn Read method.
+func (c *IPConn) Read(b []byte) (int, error) {
+ n, _, err := c.ReadFrom(b)
+ return n, err
}
-// Write implements the net.Conn Write method.
-func (c *IPConn) Write(b []byte) (n int, err os.Error) {
+// Write implements the Conn Write method.
+func (c *IPConn) Write(b []byte) (int, error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -74,7 +72,7 @@ func (c *IPConn) Write(b []byte) (n int, err os.Error) {
}
// Close closes the IP connection.
-func (c *IPConn) Close() os.Error {
+func (c *IPConn) Close() error {
if !c.ok() {
return os.EINVAL
}
@@ -99,33 +97,33 @@ func (c *IPConn) RemoteAddr() Addr {
return c.fd.raddr
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *IPConn) SetTimeout(nsec int64) os.Error {
+// SetDeadline implements the Conn SetDeadline method.
+func (c *IPConn) SetDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setTimeout(c.fd, nsec)
+ return setDeadline(c.fd, t)
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *IPConn) SetReadTimeout(nsec int64) os.Error {
+// SetReadDeadline implements the Conn SetReadDeadline method.
+func (c *IPConn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setReadTimeout(c.fd, nsec)
+ return setReadDeadline(c.fd, t)
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *IPConn) SetWriteTimeout(nsec int64) os.Error {
+// SetWriteDeadline implements the Conn SetWriteDeadline method.
+func (c *IPConn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setWriteTimeout(c.fd, nsec)
+ return setWriteDeadline(c.fd, t)
}
// SetReadBuffer sets the size of the operating system's
// receive buffer associated with the connection.
-func (c *IPConn) SetReadBuffer(bytes int) os.Error {
+func (c *IPConn) SetReadBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -134,7 +132,7 @@ func (c *IPConn) SetReadBuffer(bytes int) os.Error {
// SetWriteBuffer sets the size of the operating system's
// transmit buffer associated with the connection.
-func (c *IPConn) SetWriteBuffer(bytes int) os.Error {
+func (c *IPConn) SetWriteBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -148,14 +146,15 @@ func (c *IPConn) SetWriteBuffer(bytes int) os.Error {
// that was on the packet.
//
// ReadFromIP can be made to time out and return an error with
-// Timeout() == true after a fixed time limit; see SetTimeout and
-// SetReadTimeout.
-func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) {
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetReadDeadline.
+func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
// TODO(cw,rsc): consider using readv if we know the family
// type to avoid the header trim/copy
+ var addr *IPAddr
n, sa, err := c.fd.ReadFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
@@ -168,11 +167,11 @@ func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) {
case *syscall.SockaddrInet6:
addr = &IPAddr{sa.Addr[0:]}
}
- return
+ return n, addr, err
}
-// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -184,81 +183,37 @@ func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
//
// WriteToIP can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetWriteTimeout.
+// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
-func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) {
+func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
if !c.ok() {
return 0, os.EINVAL
}
- sa, err1 := addr.sockaddr(c.fd.family)
- if err1 != nil {
- return 0, &OpError{Op: "write", Net: "ip", Addr: addr, Error: err1}
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, &OpError{"write", c.fd.net, addr, err}
}
return c.fd.WriteTo(b, sa)
}
-// WriteTo implements the net.PacketConn WriteTo method.
-func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+// WriteTo implements the PacketConn WriteTo method.
+func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, os.EINVAL
}
a, ok := addr.(*IPAddr)
if !ok {
- return 0, &OpError{"writeto", "ip", addr, os.EINVAL}
+ return 0, &OpError{"write", c.fd.net, addr, os.EINVAL}
}
return c.WriteToIP(b, a)
}
-var protocols map[string]int
-
-func readProtocols() {
- protocols = make(map[string]int)
- if file, err := open("/etc/protocols"); err == nil {
- for line, ok := file.readLine(); ok; line, ok = file.readLine() {
- // tcp 6 TCP # transmission control protocol
- if i := byteIndex(line, '#'); i >= 0 {
- line = line[0:i]
- }
- f := getFields(line)
- if len(f) < 2 {
- continue
- }
- if proto, _, ok := dtoi(f[1], 0); ok {
- protocols[f[0]] = proto
- for _, alias := range f[2:] {
- protocols[alias] = proto
- }
- }
- }
- file.close()
- }
-}
-
-func splitNetProto(netProto string) (net string, proto int, err os.Error) {
- onceReadProtocols.Do(readProtocols)
- i := last(netProto, ':')
- if i < 0 { // no colon
- return "", 0, os.NewError("no IP protocol specified")
- }
- net = netProto[0:i]
- protostr := netProto[i+1:]
- proto, i, ok := dtoi(protostr, 0)
- if !ok || i != len(protostr) {
- // lookup by name
- proto, ok = protocols[protostr]
- if ok {
- return
- }
- }
- return
-}
-
-// DialIP connects to the remote address raddr on the network net,
-// which must be "ip", "ip4", or "ip6".
-func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) {
- net, proto, err := splitNetProto(netProto)
+// DialIP connects to the remote address raddr on the network protocol netProto,
+// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
+func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
+ net, proto, err := parseDialNetwork(netProto)
if err != nil {
- return
+ return nil, err
}
switch net {
case "ip", "ip4", "ip6":
@@ -266,11 +221,11 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) {
return nil, UnknownNetworkError(net)
}
if raddr == nil {
- return nil, &OpError{"dial", "ip", nil, errMissingAddress}
+ return nil, &OpError{"dial", netProto, nil, errMissingAddress}
}
- fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
- if e != nil {
- return nil, e
+ fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
+ if err != nil {
+ return nil, err
}
return newIPConn(fd), nil
}
@@ -279,29 +234,24 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) {
// local address laddr. The returned connection c's ReadFrom
// and WriteTo methods can be used to receive and send IP
// packets with per-packet addressing.
-func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) {
- net, proto, err := splitNetProto(netProto)
+func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
+ net, proto, err := parseDialNetwork(netProto)
if err != nil {
- return
+ return nil, err
}
switch net {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(net)
}
- fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
- if e != nil {
- return nil, e
+ fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "listen", sockaddrToIP)
+ if err != nil {
+ return nil, err
}
return newIPConn(fd), nil
}
-// BindToDevice binds an IPConn to a network interface.
-func (c *IPConn) BindToDevice(device string) os.Error {
- if !c.ok() {
- return os.EINVAL
- }
- c.fd.incref()
- defer c.fd.decref()
- return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device))
-}
+// File returns a copy of the underlying os.File, set to blocking mode.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func (c *IPConn) File() (f *os.File, err error) { return c.fd.dup() }
diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go
index 4e2a5622b..9234f5aff 100644
--- a/src/pkg/net/ipsock.go
+++ b/src/pkg/net/ipsock.go
@@ -6,14 +6,10 @@
package net
-import (
- "os"
-)
-
var supportsIPv6, supportsIPv4map = probeIPv6Stack()
func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) {
- if filter == anyaddr {
+ if filter == nil {
// We'll take any IP address, but since the dialing code
// does not yet try multiple addresses, prefer to use
// an IPv4 address if possible. This is especially relevant
@@ -61,14 +57,14 @@ func ipv6only(x IP) IP {
type InvalidAddrError string
-func (e InvalidAddrError) String() string { return string(e) }
+func (e InvalidAddrError) Error() string { return string(e) }
func (e InvalidAddrError) Timeout() bool { return false }
func (e InvalidAddrError) Temporary() bool { return false }
// SplitHostPort splits a network address of the form
// "host:port" or "[host]:port" into host and port.
// The latter form must be used when host contains a colon.
-func SplitHostPort(hostport string) (host, port string, err os.Error) {
+func SplitHostPort(hostport string) (host, port string, err error) {
// The port starts after the last colon.
i := last(hostport, ':')
if i < 0 {
@@ -102,7 +98,7 @@ func JoinHostPort(host, port string) string {
}
// Convert "host:port" into IP address and port.
-func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) {
+func hostPortToIP(net, hostport string) (ip IP, iport int, err error) {
var (
addr IP
p, i int
@@ -117,7 +113,7 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) {
// Try as an IP address.
addr = ParseIP(host)
if addr == nil {
- filter := anyaddr
+ var filter func(IP) IP
if net != "" && net[len(net)-1] == '4' {
filter = ipv4only
}
diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go
index 9e5da6d38..09d8d6b4e 100644
--- a/src/pkg/net/ipsock_plan9.go
+++ b/src/pkg/net/ipsock_plan9.go
@@ -7,7 +7,10 @@
package net
import (
+ "errors"
+ "io"
"os"
+ "time"
)
// probeIPv6Stack returns two boolean values. If the first boolean value is
@@ -18,7 +21,7 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
}
// parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80).
-func parsePlan9Addr(s string) (ip IP, iport int, err os.Error) {
+func parsePlan9Addr(s string) (ip IP, iport int, err error) {
var (
addr IP
p, i int
@@ -29,13 +32,13 @@ func parsePlan9Addr(s string) (ip IP, iport int, err os.Error) {
if i >= 0 {
addr = ParseIP(s[:i])
if addr == nil {
- err = os.NewError("net: parsing IP failed")
+ err = errors.New("net: parsing IP failed")
goto Error
}
}
p, _, ok = dtoi(s[i+1:], 0)
if !ok {
- err = os.NewError("net: parsing port failed")
+ err = errors.New("net: parsing port failed")
goto Error
}
if p < 0 || p > 0xFFFF {
@@ -48,7 +51,7 @@ Error:
return nil, 0, err
}
-func readPlan9Addr(proto, filename string) (addr Addr, err os.Error) {
+func readPlan9Addr(proto, filename string) (addr Addr, err error) {
var buf [128]byte
f, err := os.Open(filename)
@@ -69,7 +72,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err os.Error) {
case "udp":
addr = &UDPAddr{ip, port}
default:
- return nil, os.NewError("unknown protocol " + proto)
+ return nil, errors.New("unknown protocol " + proto)
}
return addr, nil
}
@@ -89,7 +92,7 @@ func (c *plan9Conn) ok() bool { return c != nil && c.ctl != nil }
// Implementation of the Conn interface - see Conn for documentation.
// Read implements the net.Conn Read method.
-func (c *plan9Conn) Read(b []byte) (n int, err os.Error) {
+func (c *plan9Conn) Read(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -100,7 +103,7 @@ func (c *plan9Conn) Read(b []byte) (n int, err os.Error) {
}
}
n, err = c.data.Read(b)
- if c.proto == "udp" && err == os.EOF {
+ if c.proto == "udp" && err == io.EOF {
n = 0
err = nil
}
@@ -108,7 +111,7 @@ func (c *plan9Conn) Read(b []byte) (n int, err os.Error) {
}
// Write implements the net.Conn Write method.
-func (c *plan9Conn) Write(b []byte) (n int, err os.Error) {
+func (c *plan9Conn) Write(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -122,7 +125,7 @@ func (c *plan9Conn) Write(b []byte) (n int, err os.Error) {
}
// Close closes the connection.
-func (c *plan9Conn) Close() os.Error {
+func (c *plan9Conn) Close() error {
if !c.ok() {
return os.EINVAL
}
@@ -154,31 +157,22 @@ func (c *plan9Conn) RemoteAddr() Addr {
return c.raddr
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *plan9Conn) SetTimeout(nsec int64) os.Error {
- if !c.ok() {
- return os.EINVAL
- }
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *plan9Conn) SetDeadline(t time.Time) error {
return os.EPLAN9
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *plan9Conn) SetReadTimeout(nsec int64) os.Error {
- if !c.ok() {
- return os.EINVAL
- }
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *plan9Conn) SetReadDeadline(t time.Time) error {
return os.EPLAN9
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *plan9Conn) SetWriteTimeout(nsec int64) os.Error {
- if !c.ok() {
- return os.EINVAL
- }
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *plan9Conn) SetWriteDeadline(t time.Time) error {
return os.EPLAN9
}
-func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err os.Error) {
+func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
var (
ip IP
port int
@@ -213,7 +207,7 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string,
return f, dest, proto, string(buf[:n]), nil
}
-func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err os.Error) {
+func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err error) {
f, dest, proto, name, err := startPlan9(net, raddr)
if err != nil {
return
@@ -239,7 +233,7 @@ type plan9Listener struct {
laddr Addr
}
-func listenPlan9(net string, laddr Addr) (l *plan9Listener, err os.Error) {
+func listenPlan9(net string, laddr Addr) (l *plan9Listener, err error) {
f, dest, proto, name, err := startPlan9(net, laddr)
if err != nil {
return
@@ -265,7 +259,7 @@ func (l *plan9Listener) plan9Conn() *plan9Conn {
return newPlan9Conn(l.proto, l.name, l.ctl, l.laddr, nil)
}
-func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err os.Error) {
+func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err error) {
f, err := os.Open(l.dir + "/listen")
if err != nil {
return
@@ -287,7 +281,7 @@ func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err os.Error) {
return newPlan9Conn(l.proto, name, f, laddr, raddr), nil
}
-func (l *plan9Listener) Accept() (c Conn, err os.Error) {
+func (l *plan9Listener) Accept() (c Conn, err error) {
c1, err := l.acceptPlan9()
if err != nil {
return
@@ -295,7 +289,7 @@ func (l *plan9Listener) Accept() (c Conn, err os.Error) {
return c1, nil
}
-func (l *plan9Listener) Close() os.Error {
+func (l *plan9Listener) Close() error {
if l == nil || l.ctl == nil {
return os.EINVAL
}
diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go
index 049df9ea4..3a059f516 100644
--- a/src/pkg/net/ipsock_posix.go
+++ b/src/pkg/net/ipsock_posix.go
@@ -2,14 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
package net
-import (
- "os"
- "syscall"
-)
+import "syscall"
// Should we try to use the IPv4 socket interface if we're
// only dealing with IPv4 sockets? As long as the host system
@@ -36,8 +33,8 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
}
for i := range probes {
- s, errno := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
- if errno != 0 {
+ s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ if err != nil {
continue
}
defer closesocket(s)
@@ -45,8 +42,8 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
if err != nil {
continue
}
- errno = syscall.Bind(s, sa)
- if errno != 0 {
+ err = syscall.Bind(s, sa)
+ if err != nil {
continue
}
probes[i].ok = true
@@ -94,23 +91,18 @@ func favoriteAddrFamily(net string, raddr, laddr sockaddr, mode string) int {
return syscall.AF_INET6
}
-// TODO(rsc): if syscall.OS == "linux", we're supposed to read
-// /proc/sys/net/core/somaxconn,
-// to take advantage of kernels that have raised the limit.
-func listenBacklog() int { return syscall.SOMAXCONN }
-
// Internet sockets (TCP, UDP)
// A sockaddr represents a TCP or UDP network address that can
// be converted into a syscall.Sockaddr.
type sockaddr interface {
Addr
- sockaddr(family int) (syscall.Sockaddr, os.Error)
+ sockaddr(family int) (syscall.Sockaddr, error)
family() int
}
-func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) {
- var oserr os.Error
+func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
+ var oserr error
var la, ra syscall.Sockaddr
family := favoriteAddrFamily(net, raddr, laddr, mode)
if laddr != nil {
@@ -123,7 +115,7 @@ func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode
goto Error
}
}
- fd, oserr = socket(net, family, socktype, proto, la, ra, toAddr)
+ fd, oserr = socket(net, family, sotype, proto, la, ra, toAddr)
if oserr != nil {
goto Error
}
@@ -137,7 +129,7 @@ Error:
return nil, &OpError{mode, net, addr, oserr}
}
-func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) {
+func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) {
switch family {
case syscall.AF_INET:
if len(ip) == 0 {
diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go
index ee0c9e879..c0bb9225a 100644
--- a/src/pkg/net/lookup_plan9.go
+++ b/src/pkg/net/lookup_plan9.go
@@ -5,10 +5,11 @@
package net
import (
+ "errors"
"os"
)
-func query(filename, query string, bufSize int) (res []string, err os.Error) {
+func query(filename, query string, bufSize int) (res []string, err error) {
file, err := os.OpenFile(filename, os.O_RDWR, 0)
if err != nil {
return
@@ -34,7 +35,7 @@ func query(filename, query string, bufSize int) (res []string, err os.Error) {
return
}
-func queryCS(net, host, service string) (res []string, err os.Error) {
+func queryCS(net, host, service string) (res []string, err error) {
switch net {
case "tcp4", "tcp6":
net = "tcp"
@@ -47,9 +48,9 @@ func queryCS(net, host, service string) (res []string, err os.Error) {
return query("/net/cs", net+"!"+host+"!"+service, 128)
}
-func queryCS1(net string, ip IP, port int) (clone, dest string, err os.Error) {
+func queryCS1(net string, ip IP, port int) (clone, dest string, err error) {
ips := "*"
- if !ip.IsUnspecified() {
+ if len(ip) != 0 && !ip.IsUnspecified() {
ips = ip.String()
}
lines, err := queryCS(net, ips, itoa(port))
@@ -58,19 +59,22 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err os.Error) {
}
f := getFields(lines[0])
if len(f) < 2 {
- return "", "", os.NewError("net: bad response from ndb/cs")
+ return "", "", errors.New("net: bad response from ndb/cs")
}
clone, dest = f[0], f[1]
return
}
-func queryDNS(addr string, typ string) (res []string, err os.Error) {
+func queryDNS(addr string, typ string) (res []string, err error) {
return query("/net/dns", addr+" "+typ, 1024)
}
-// LookupHost looks up the given host using the local resolver.
-// It returns an array of that host's addresses.
-func LookupHost(host string) (addrs []string, err os.Error) {
+func lookupProtocol(name string) (proto int, err error) {
+ // TODO: Implement this
+ return 0, os.EPLAN9
+}
+
+func lookupHost(host string) (addrs []string, err error) {
// Use /net/cs insead of /net/dns because cs knows about
// host names in local network (e.g. from /lib/ndb/local)
lines, err := queryCS("tcp", host, "1")
@@ -94,9 +98,7 @@ func LookupHost(host string) (addrs []string, err os.Error) {
return
}
-// LookupIP looks up host using the local resolver.
-// It returns an array of that host's IPv4 and IPv6 addresses.
-func LookupIP(host string) (ips []IP, err os.Error) {
+func lookupIP(host string) (ips []IP, err error) {
addrs, err := LookupHost(host)
if err != nil {
return
@@ -109,8 +111,7 @@ func LookupIP(host string) (ips []IP, err os.Error) {
return
}
-// LookupPort looks up the port for the given network and service.
-func LookupPort(network, service string) (port int, err os.Error) {
+func lookupPort(network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
@@ -139,11 +140,7 @@ func LookupPort(network, service string) (port int, err os.Error) {
return 0, unknownPortError
}
-// LookupCNAME returns the canonical DNS host for the given name.
-// Callers that do not care about the canonical name can call
-// LookupHost or LookupIP directly; both take care of resolving
-// the canonical name as part of the lookup.
-func LookupCNAME(name string) (cname string, err os.Error) {
+func lookupCNAME(name string) (cname string, err error) {
lines, err := queryDNS(name, "cname")
if err != nil {
return
@@ -153,16 +150,16 @@ func LookupCNAME(name string) (cname string, err os.Error) {
return f[2] + ".", nil
}
}
- return "", os.NewError("net: bad response from ndb/dns")
+ return "", errors.New("net: bad response from ndb/dns")
}
-// LookupSRV tries to resolve an SRV query of the given service,
-// protocol, and domain name, as specified in RFC 2782. In most cases
-// the proto argument can be the same as the corresponding
-// Addr.Network(). The returned records are sorted by priority
-// and randomized by weight within a priority.
-func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) {
- target := "_" + service + "._" + proto + "." + name
+func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+ var target string
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
lines, err := queryDNS(target, "srv")
if err != nil {
return
@@ -185,8 +182,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.
return
}
-// LookupMX returns the DNS MX records for the given domain name sorted by preference.
-func LookupMX(name string) (mx []*MX, err os.Error) {
+func lookupMX(name string) (mx []*MX, err error) {
lines, err := queryDNS(name, "mx")
if err != nil {
return
@@ -204,14 +200,20 @@ func LookupMX(name string) (mx []*MX, err os.Error) {
return
}
-// LookupTXT returns the DNS TXT records for the given domain name.
-func LookupTXT(name string) (txt []string, err os.Error) {
- return nil, os.NewError("net.LookupTXT is not implemented on Plan 9")
+func lookupTXT(name string) (txt []string, err error) {
+ lines, err := queryDNS(name, "txt")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ if i := byteIndex(line, '\t'); i >= 0 {
+ txt = append(txt, line[i+1:])
+ }
+ }
+ return
}
-// LookupAddr performs a reverse lookup for the given address, returning a list
-// of names mapping to that address.
-func LookupAddr(addr string) (name []string, err os.Error) {
+func lookupAddr(addr string) (name []string, err error) {
arpa, err := reverseaddr(addr)
if err != nil {
return
diff --git a/src/pkg/net/lookup_test.go b/src/pkg/net/lookup_test.go
index 41066fe48..9a39ca8a1 100644
--- a/src/pkg/net/lookup_test.go
+++ b/src/pkg/net/lookup_test.go
@@ -26,6 +26,15 @@ func TestGoogleSRV(t *testing.T) {
if len(addrs) == 0 {
t.Errorf("no results")
}
+
+ // Non-standard back door.
+ _, addrs, err = LookupSRV("", "", "_xmpp-server._tcp.google.com")
+ if err != nil {
+ t.Errorf("back door failed: %s", err)
+ }
+ if len(addrs) == 0 {
+ t.Errorf("back door no results")
+ }
}
func TestGmailMX(t *testing.T) {
@@ -43,10 +52,6 @@ func TestGmailMX(t *testing.T) {
}
func TestGmailTXT(t *testing.T) {
- if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
- t.Logf("LookupTXT is not implemented on Windows or Plan 9")
- return
- }
if testing.Short() || avoidMacFirewall {
t.Logf("skipping test to avoid external network")
return
diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go
index 7368b751e..d500a1240 100644
--- a/src/pkg/net/lookup_unix.go
+++ b/src/pkg/net/lookup_unix.go
@@ -2,17 +2,57 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
package net
import (
- "os"
+ "errors"
+ "sync"
)
-// LookupHost looks up the given host using the local resolver.
-// It returns an array of that host's addresses.
-func LookupHost(host string) (addrs []string, err os.Error) {
+var (
+ protocols map[string]int
+ onceReadProtocols sync.Once
+)
+
+// readProtocols loads contents of /etc/protocols into protocols map
+// for quick access.
+func readProtocols() {
+ protocols = make(map[string]int)
+ if file, err := open("/etc/protocols"); err == nil {
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ // tcp 6 TCP # transmission control protocol
+ if i := byteIndex(line, '#'); i >= 0 {
+ line = line[0:i]
+ }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ if proto, _, ok := dtoi(f[1], 0); ok {
+ protocols[f[0]] = proto
+ for _, alias := range f[2:] {
+ protocols[alias] = proto
+ }
+ }
+ }
+ file.close()
+ }
+}
+
+// lookupProtocol looks up IP protocol name in /etc/protocols and
+// returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err error) {
+ onceReadProtocols.Do(readProtocols)
+ proto, found := protocols[name]
+ if !found {
+ return 0, errors.New("unknown IP protocol specified: " + name)
+ }
+ return
+}
+
+func lookupHost(host string) (addrs []string, err error) {
addrs, err, ok := cgoLookupHost(host)
if !ok {
addrs, err = goLookupHost(host)
@@ -20,9 +60,7 @@ func LookupHost(host string) (addrs []string, err os.Error) {
return
}
-// LookupIP looks up host using the local resolver.
-// It returns an array of that host's IPv4 and IPv6 addresses.
-func LookupIP(host string) (addrs []IP, err os.Error) {
+func lookupIP(host string) (addrs []IP, err error) {
addrs, err, ok := cgoLookupIP(host)
if !ok {
addrs, err = goLookupIP(host)
@@ -30,8 +68,7 @@ func LookupIP(host string) (addrs []IP, err os.Error) {
return
}
-// LookupPort looks up the port for the given network and service.
-func LookupPort(network, service string) (port int, err os.Error) {
+func lookupPort(network, service string) (port int, err error) {
port, err, ok := cgoLookupPort(network, service)
if !ok {
port, err = goLookupPort(network, service)
@@ -39,11 +76,7 @@ func LookupPort(network, service string) (port int, err os.Error) {
return
}
-// LookupCNAME returns the canonical DNS host for the given name.
-// Callers that do not care about the canonical name can call
-// LookupHost or LookupIP directly; both take care of resolving
-// the canonical name as part of the lookup.
-func LookupCNAME(name string) (cname string, err os.Error) {
+func lookupCNAME(name string) (cname string, err error) {
cname, err, ok := cgoLookupCNAME(name)
if !ok {
cname, err = goLookupCNAME(name)
@@ -51,13 +84,13 @@ func LookupCNAME(name string) (cname string, err os.Error) {
return
}
-// LookupSRV tries to resolve an SRV query of the given service,
-// protocol, and domain name, as specified in RFC 2782. In most cases
-// the proto argument can be the same as the corresponding
-// Addr.Network(). The returned records are sorted by priority
-// and randomized by weight within a priority.
-func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) {
- target := "_" + service + "._" + proto + "." + name
+func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+ var target string
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
var records []dnsRR
cname, records, err = lookup(target, dnsTypeSRV)
if err != nil {
@@ -72,8 +105,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.
return
}
-// LookupMX returns the DNS MX records for the given domain name sorted by preference.
-func LookupMX(name string) (mx []*MX, err os.Error) {
+func lookupMX(name string) (mx []*MX, err error) {
_, records, err := lookup(name, dnsTypeMX)
if err != nil {
return
@@ -87,8 +119,7 @@ func LookupMX(name string) (mx []*MX, err os.Error) {
return
}
-// LookupTXT returns the DNS TXT records for the given domain name.
-func LookupTXT(name string) (txt []string, err os.Error) {
+func lookupTXT(name string) (txt []string, err error) {
_, records, err := lookup(name, dnsTypeTXT)
if err != nil {
return
@@ -100,9 +131,7 @@ func LookupTXT(name string) (txt []string, err os.Error) {
return
}
-// LookupAddr performs a reverse lookup for the given address, returning a list
-// of names mapping to that address.
-func LookupAddr(addr string) (name []string, err os.Error) {
+func lookupAddr(addr string) (name []string, err error) {
name = lookupStaticAddr(addr)
if len(name) > 0 {
return
diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go
index b33c7f949..dfe2ff6f1 100644
--- a/src/pkg/net/lookup_windows.go
+++ b/src/pkg/net/lookup_windows.go
@@ -5,16 +5,30 @@
package net
import (
- "syscall"
- "unsafe"
"os"
"sync"
+ "syscall"
+ "unsafe"
)
-var hostentLock sync.Mutex
-var serventLock sync.Mutex
+var (
+ protoentLock sync.Mutex
+ hostentLock sync.Mutex
+ serventLock sync.Mutex
+)
+
+// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err error) {
+ protoentLock.Lock()
+ defer protoentLock.Unlock()
+ p, e := syscall.GetProtoByName(name)
+ if e != nil {
+ return 0, os.NewSyscallError("GetProtoByName", e)
+ }
+ return int(p.Proto), nil
+}
-func LookupHost(name string) (addrs []string, err os.Error) {
+func lookupHost(name string) (addrs []string, err error) {
ips, err := LookupIP(name)
if err != nil {
return
@@ -26,11 +40,11 @@ func LookupHost(name string) (addrs []string, err os.Error) {
return
}
-func LookupIP(name string) (addrs []IP, err os.Error) {
+func lookupIP(name string) (addrs []IP, err error) {
hostentLock.Lock()
defer hostentLock.Unlock()
h, e := syscall.GetHostByName(name)
- if e != 0 {
+ if e != nil {
return nil, os.NewSyscallError("GetHostByName", e)
}
switch h.AddrType {
@@ -47,7 +61,7 @@ func LookupIP(name string) (addrs []IP, err os.Error) {
return addrs, nil
}
-func LookupPort(network, service string) (port int, err os.Error) {
+func lookupPort(network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
@@ -57,17 +71,17 @@ func LookupPort(network, service string) (port int, err os.Error) {
serventLock.Lock()
defer serventLock.Unlock()
s, e := syscall.GetServByName(service, network)
- if e != 0 {
+ if e != nil {
return 0, os.NewSyscallError("GetServByName", e)
}
return int(syscall.Ntohs(s.Port)), nil
}
-func LookupCNAME(name string) (cname string, err os.Error) {
+func lookupCNAME(name string) (cname string, err error) {
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
- if int(e) != 0 {
- return "", os.NewSyscallError("LookupCNAME", int(e))
+ if e != nil {
+ return "", os.NewSyscallError("LookupCNAME", e)
}
defer syscall.DnsRecordListFree(r, 1)
if r != nil && r.Type == syscall.DNS_TYPE_CNAME {
@@ -77,12 +91,17 @@ func LookupCNAME(name string) (cname string, err os.Error) {
return
}
-func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) {
+func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+ var target string
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
var r *syscall.DNSRecord
- target := "_" + service + "._" + proto + "." + name
e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
- if int(e) != 0 {
- return "", nil, os.NewSyscallError("LookupSRV", int(e))
+ if e != nil {
+ return "", nil, os.NewSyscallError("LookupSRV", e)
}
defer syscall.DnsRecordListFree(r, 1)
addrs = make([]*SRV, 0, 10)
@@ -94,11 +113,11 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.
return name, addrs, nil
}
-func LookupMX(name string) (mx []*MX, err os.Error) {
+func lookupMX(name string) (mx []*MX, err error) {
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
- if int(e) != 0 {
- return nil, os.NewSyscallError("LookupMX", int(e))
+ if e != nil {
+ return nil, os.NewSyscallError("LookupMX", e)
}
defer syscall.DnsRecordListFree(r, 1)
mx = make([]*MX, 0, 10)
@@ -110,19 +129,33 @@ func LookupMX(name string) (mx []*MX, err os.Error) {
return mx, nil
}
-func LookupTXT(name string) (txt []string, err os.Error) {
- return nil, os.NewError("net.LookupTXT is not implemented on Windows")
+func lookupTXT(name string) (txt []string, err error) {
+ var r *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
+ if e != nil {
+ return nil, os.NewSyscallError("LookupTXT", e)
+ }
+ defer syscall.DnsRecordListFree(r, 1)
+ txt = make([]string, 0, 10)
+ if r != nil && r.Type == syscall.DNS_TYPE_TEXT {
+ d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0]))
+ for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
+ s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
+ txt = append(txt, s)
+ }
+ }
+ return
}
-func LookupAddr(addr string) (name []string, err os.Error) {
+func lookupAddr(addr string) (name []string, err error) {
arpa, err := reverseaddr(addr)
if err != nil {
return nil, err
}
var r *syscall.DNSRecord
e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
- if int(e) != 0 {
- return nil, os.NewSyscallError("LookupAddr", int(e))
+ if e != nil {
+ return nil, os.NewSyscallError("LookupAddr", e)
}
defer syscall.DnsRecordListFree(r, 1)
name = make([]string, 0, 10)
diff --git a/src/pkg/net/mail/Makefile b/src/pkg/net/mail/Makefile
new file mode 100644
index 000000000..acb1c2a6d
--- /dev/null
+++ b/src/pkg/net/mail/Makefile
@@ -0,0 +1,11 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=net/mail
+GOFILES=\
+ message.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/net/mail/message.go b/src/pkg/net/mail/message.go
new file mode 100644
index 000000000..bf22c711e
--- /dev/null
+++ b/src/pkg/net/mail/message.go
@@ -0,0 +1,524 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package mail implements parsing of mail messages.
+
+For the most part, this package follows the syntax as specified by RFC 5322.
+Notable divergences:
+ * Obsolete address formats are not parsed, including addresses with
+ embedded route information.
+ * Group addresses are not parsed.
+ * The full range of spacing (the CFWS syntax element) is not supported,
+ such as breaking addresses across lines.
+*/
+package mail
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/textproto"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var debug = debugT(false)
+
+type debugT bool
+
+func (d debugT) Printf(format string, args ...interface{}) {
+ if d {
+ log.Printf(format, args...)
+ }
+}
+
+// A Message represents a parsed mail message.
+type Message struct {
+ Header Header
+ Body io.Reader
+}
+
+// ReadMessage reads a message from r.
+// The headers are parsed, and the body of the message will be reading from r.
+func ReadMessage(r io.Reader) (msg *Message, err error) {
+ tp := textproto.NewReader(bufio.NewReader(r))
+
+ hdr, err := tp.ReadMIMEHeader()
+ if err != nil {
+ return nil, err
+ }
+
+ return &Message{
+ Header: Header(hdr),
+ Body: tp.R,
+ }, nil
+}
+
+// Layouts suitable for passing to time.Parse.
+// These are tried in order.
+var dateLayouts []string
+
+func init() {
+ // Generate layouts based on RFC 5322, section 3.3.
+
+ dows := [...]string{"", "Mon, "} // day-of-week
+ days := [...]string{"2", "02"} // day = 1*2DIGIT
+ years := [...]string{"2006", "06"} // year = 4*DIGIT / 2*DIGIT
+ seconds := [...]string{":05", ""} // second
+ zones := [...]string{"-0700", "MST"} // zone = (("+" / "-") 4DIGIT) / "GMT" / ...
+
+ for _, dow := range dows {
+ for _, day := range days {
+ for _, year := range years {
+ for _, second := range seconds {
+ for _, zone := range zones {
+ s := dow + day + " Jan " + year + " 15:04" + second + " " + zone
+ dateLayouts = append(dateLayouts, s)
+ }
+ }
+ }
+ }
+ }
+}
+
+func parseDate(date string) (time.Time, error) {
+ for _, layout := range dateLayouts {
+ t, err := time.Parse(layout, date)
+ if err == nil {
+ return t, nil
+ }
+ }
+ return time.Time{}, errors.New("mail: header could not be parsed")
+}
+
+// A Header represents the key-value pairs in a mail message header.
+type Header map[string][]string
+
+// Get gets the first value associated with the given key.
+// If there are no values associated with the key, Get returns "".
+func (h Header) Get(key string) string {
+ return textproto.MIMEHeader(h).Get(key)
+}
+
+var ErrHeaderNotPresent = errors.New("mail: header not in message")
+
+// Date parses the Date header field.
+func (h Header) Date() (time.Time, error) {
+ hdr := h.Get("Date")
+ if hdr == "" {
+ return time.Time{}, ErrHeaderNotPresent
+ }
+ return parseDate(hdr)
+}
+
+// AddressList parses the named header field as a list of addresses.
+func (h Header) AddressList(key string) ([]*Address, error) {
+ hdr := h.Get(key)
+ if hdr == "" {
+ return nil, ErrHeaderNotPresent
+ }
+ return newAddrParser(hdr).parseAddressList()
+}
+
+// Address represents a single mail address.
+// An address such as "Barry Gibbs <bg@example.com>" is represented
+// as Address{Name: "Barry Gibbs", Address: "bg@example.com"}.
+type Address struct {
+ Name string // Proper name; may be empty.
+ Address string // user@domain
+}
+
+// String formats the address as a valid RFC 5322 address.
+// If the address's name contains non-ASCII characters
+// the name will be rendered according to RFC 2047.
+func (a *Address) String() string {
+ s := "<" + a.Address + ">"
+ if a.Name == "" {
+ return s
+ }
+ // If every character is printable ASCII, quoting is simple.
+ allPrintable := true
+ for i := 0; i < len(a.Name); i++ {
+ if !isVchar(a.Name[i]) {
+ allPrintable = false
+ break
+ }
+ }
+ if allPrintable {
+ b := bytes.NewBufferString(`"`)
+ for i := 0; i < len(a.Name); i++ {
+ if !isQtext(a.Name[i]) {
+ b.WriteByte('\\')
+ }
+ b.WriteByte(a.Name[i])
+ }
+ b.WriteString(`" `)
+ b.WriteString(s)
+ return b.String()
+ }
+
+ // UTF-8 "Q" encoding
+ b := bytes.NewBufferString("=?utf-8?q?")
+ for i := 0; i < len(a.Name); i++ {
+ switch c := a.Name[i]; {
+ case c == ' ':
+ b.WriteByte('_')
+ case isVchar(c) && c != '=' && c != '?' && c != '_':
+ b.WriteByte(c)
+ default:
+ fmt.Fprintf(b, "=%02X", c)
+ }
+ }
+ b.WriteString("?= ")
+ b.WriteString(s)
+ return b.String()
+}
+
+type addrParser []byte
+
+func newAddrParser(s string) *addrParser {
+ p := addrParser(s)
+ return &p
+}
+
+func (p *addrParser) parseAddressList() ([]*Address, error) {
+ var list []*Address
+ for {
+ p.skipSpace()
+ addr, err := p.parseAddress()
+ if err != nil {
+ return nil, err
+ }
+ list = append(list, addr)
+
+ p.skipSpace()
+ if p.empty() {
+ break
+ }
+ if !p.consume(',') {
+ return nil, errors.New("mail: expected comma")
+ }
+ }
+ return list, nil
+}
+
+// parseAddress parses a single RFC 5322 address at the start of p.
+func (p *addrParser) parseAddress() (addr *Address, err error) {
+ debug.Printf("parseAddress: %q", *p)
+ p.skipSpace()
+ if p.empty() {
+ return nil, errors.New("mail: no address")
+ }
+
+ // address = name-addr / addr-spec
+ // TODO(dsymonds): Support parsing group address.
+
+ // addr-spec has a more restricted grammar than name-addr,
+ // so try parsing it first, and fallback to name-addr.
+ // TODO(dsymonds): Is this really correct?
+ spec, err := p.consumeAddrSpec()
+ if err == nil {
+ return &Address{
+ Address: spec,
+ }, err
+ }
+ debug.Printf("parseAddress: not an addr-spec: %v", err)
+ debug.Printf("parseAddress: state is now %q", *p)
+
+ // display-name
+ var displayName string
+ if p.peek() != '<' {
+ displayName, err = p.consumePhrase()
+ if err != nil {
+ return nil, err
+ }
+ }
+ debug.Printf("parseAddress: displayName=%q", displayName)
+
+ // angle-addr = "<" addr-spec ">"
+ p.skipSpace()
+ if !p.consume('<') {
+ return nil, errors.New("mail: no angle-addr")
+ }
+ spec, err = p.consumeAddrSpec()
+ if err != nil {
+ return nil, err
+ }
+ if !p.consume('>') {
+ return nil, errors.New("mail: unclosed angle-addr")
+ }
+ debug.Printf("parseAddress: spec=%q", spec)
+
+ return &Address{
+ Name: displayName,
+ Address: spec,
+ }, nil
+}
+
+// consumeAddrSpec parses a single RFC 5322 addr-spec at the start of p.
+func (p *addrParser) consumeAddrSpec() (spec string, err error) {
+ debug.Printf("consumeAddrSpec: %q", *p)
+
+ orig := *p
+ defer func() {
+ if err != nil {
+ *p = orig
+ }
+ }()
+
+ // local-part = dot-atom / quoted-string
+ var localPart string
+ p.skipSpace()
+ if p.empty() {
+ return "", errors.New("mail: no addr-spec")
+ }
+ if p.peek() == '"' {
+ // quoted-string
+ debug.Printf("consumeAddrSpec: parsing quoted-string")
+ localPart, err = p.consumeQuotedString()
+ } else {
+ // dot-atom
+ debug.Printf("consumeAddrSpec: parsing dot-atom")
+ localPart, err = p.consumeAtom(true)
+ }
+ if err != nil {
+ debug.Printf("consumeAddrSpec: failed: %v", err)
+ return "", err
+ }
+
+ if !p.consume('@') {
+ return "", errors.New("mail: missing @ in addr-spec")
+ }
+
+ // domain = dot-atom / domain-literal
+ var domain string
+ p.skipSpace()
+ if p.empty() {
+ return "", errors.New("mail: no domain in addr-spec")
+ }
+ // TODO(dsymonds): Handle domain-literal
+ domain, err = p.consumeAtom(true)
+ if err != nil {
+ return "", err
+ }
+
+ return localPart + "@" + domain, nil
+}
+
+// consumePhrase parses the RFC 5322 phrase at the start of p.
+func (p *addrParser) consumePhrase() (phrase string, err error) {
+ debug.Printf("consumePhrase: [%s]", *p)
+ // phrase = 1*word
+ var words []string
+ for {
+ // word = atom / quoted-string
+ var word string
+ p.skipSpace()
+ if p.empty() {
+ return "", errors.New("mail: missing phrase")
+ }
+ if p.peek() == '"' {
+ // quoted-string
+ word, err = p.consumeQuotedString()
+ } else {
+ // atom
+ word, err = p.consumeAtom(false)
+ }
+
+ // RFC 2047 encoded-word starts with =?, ends with ?=, and has two other ?s.
+ if err == nil && strings.HasPrefix(word, "=?") && strings.HasSuffix(word, "?=") && strings.Count(word, "?") == 4 {
+ word, err = decodeRFC2047Word(word)
+ }
+
+ if err != nil {
+ break
+ }
+ debug.Printf("consumePhrase: consumed %q", word)
+ words = append(words, word)
+ }
+ // Ignore any error if we got at least one word.
+ if err != nil && len(words) == 0 {
+ debug.Printf("consumePhrase: hit err: %v", err)
+ return "", errors.New("mail: missing word in phrase")
+ }
+ phrase = strings.Join(words, " ")
+ return phrase, nil
+}
+
+// consumeQuotedString parses the quoted string at the start of p.
+func (p *addrParser) consumeQuotedString() (qs string, err error) {
+ // Assume first byte is '"'.
+ i := 1
+ qsb := make([]byte, 0, 10)
+Loop:
+ for {
+ if i >= p.len() {
+ return "", errors.New("mail: unclosed quoted-string")
+ }
+ switch c := (*p)[i]; {
+ case c == '"':
+ break Loop
+ case c == '\\':
+ if i+1 == p.len() {
+ return "", errors.New("mail: unclosed quoted-string")
+ }
+ qsb = append(qsb, (*p)[i+1])
+ i += 2
+ case isQtext(c), c == ' ' || c == '\t':
+ // qtext (printable US-ASCII excluding " and \), or
+ // FWS (almost; we're ignoring CRLF)
+ qsb = append(qsb, c)
+ i++
+ default:
+ return "", fmt.Errorf("mail: bad character in quoted-string: %q", c)
+ }
+ }
+ *p = (*p)[i+1:]
+ return string(qsb), nil
+}
+
+// consumeAtom parses an RFC 5322 atom at the start of p.
+// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead.
+func (p *addrParser) consumeAtom(dot bool) (atom string, err error) {
+ if !isAtext(p.peek(), false) {
+ return "", errors.New("mail: invalid string")
+ }
+ i := 1
+ for ; i < p.len() && isAtext((*p)[i], dot); i++ {
+ }
+ // TODO(dsymonds): Remove the []byte() conversion here when 6g doesn't need it.
+ atom, *p = string([]byte((*p)[:i])), (*p)[i:]
+ return atom, nil
+}
+
+func (p *addrParser) consume(c byte) bool {
+ if p.empty() || p.peek() != c {
+ return false
+ }
+ *p = (*p)[1:]
+ return true
+}
+
+// skipSpace skips the leading space and tab characters.
+func (p *addrParser) skipSpace() {
+ *p = bytes.TrimLeft(*p, " \t")
+}
+
+func (p *addrParser) peek() byte {
+ return (*p)[0]
+}
+
+func (p *addrParser) empty() bool {
+ return p.len() == 0
+}
+
+func (p *addrParser) len() int {
+ return len(*p)
+}
+
+func decodeRFC2047Word(s string) (string, error) {
+ fields := strings.Split(s, "?")
+ if len(fields) != 5 || fields[0] != "=" || fields[4] != "=" {
+ return "", errors.New("mail: address not RFC 2047 encoded")
+ }
+ charset, enc := strings.ToLower(fields[1]), strings.ToLower(fields[2])
+ if charset != "iso-8859-1" && charset != "utf-8" {
+ return "", fmt.Errorf("mail: charset not supported: %q", charset)
+ }
+
+ in := bytes.NewBufferString(fields[3])
+ var r io.Reader
+ switch enc {
+ case "b":
+ r = base64.NewDecoder(base64.StdEncoding, in)
+ case "q":
+ r = qDecoder{r: in}
+ default:
+ return "", fmt.Errorf("mail: RFC 2047 encoding not supported: %q", enc)
+ }
+
+ dec, err := ioutil.ReadAll(r)
+ if err != nil {
+ return "", err
+ }
+
+ switch charset {
+ case "iso-8859-1":
+ b := new(bytes.Buffer)
+ for _, c := range dec {
+ b.WriteRune(rune(c))
+ }
+ return b.String(), nil
+ case "utf-8":
+ return string(dec), nil
+ }
+ panic("unreachable")
+}
+
+type qDecoder struct {
+ r io.Reader
+ scratch [2]byte
+}
+
+func (qd qDecoder) Read(p []byte) (n int, err error) {
+ // This method writes at most one byte into p.
+ if len(p) == 0 {
+ return 0, nil
+ }
+ if _, err := qd.r.Read(qd.scratch[:1]); err != nil {
+ return 0, err
+ }
+ switch c := qd.scratch[0]; {
+ case c == '=':
+ if _, err := io.ReadFull(qd.r, qd.scratch[:2]); err != nil {
+ return 0, err
+ }
+ x, err := strconv.ParseInt(string(qd.scratch[:2]), 16, 64)
+ if err != nil {
+ return 0, fmt.Errorf("mail: invalid RFC 2047 encoding: %q", qd.scratch[:2])
+ }
+ p[0] = byte(x)
+ case c == '_':
+ p[0] = ' '
+ default:
+ p[0] = c
+ }
+ return 1, nil
+}
+
+var atextChars = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
+ "abcdefghijklmnopqrstuvwxyz" +
+ "0123456789" +
+ "!#$%&'*+-/=?^_`{|}~")
+
+// isAtext returns true if c is an RFC 5322 atext character.
+// If dot is true, period is included.
+func isAtext(c byte, dot bool) bool {
+ if dot && c == '.' {
+ return true
+ }
+ return bytes.IndexByte(atextChars, c) >= 0
+}
+
+// isQtext returns true if c is an RFC 5322 qtest character.
+func isQtext(c byte) bool {
+ // Printable US-ASCII, excluding backslash or quote.
+ if c == '\\' || c == '"' {
+ return false
+ }
+ return '!' <= c && c <= '~'
+}
+
+// isVchar returns true if c is an RFC 5322 VCHAR character.
+func isVchar(c byte) bool {
+ // Visible (printing) characters.
+ return '!' <= c && c <= '~'
+}
diff --git a/src/pkg/net/mail/message_test.go b/src/pkg/net/mail/message_test.go
new file mode 100644
index 000000000..671ff2efa
--- /dev/null
+++ b/src/pkg/net/mail/message_test.go
@@ -0,0 +1,261 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mail
+
+import (
+ "bytes"
+ "io/ioutil"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var parseTests = []struct {
+ in string
+ header Header
+ body string
+}{
+ {
+ // RFC 5322, Appendix A.1.1
+ in: `From: John Doe <jdoe@machine.example>
+To: Mary Smith <mary@example.net>
+Subject: Saying Hello
+Date: Fri, 21 Nov 1997 09:55:06 -0600
+Message-ID: <1234@local.machine.example>
+
+This is a message just to say hello.
+So, "Hello".
+`,
+ header: Header{
+ "From": []string{"John Doe <jdoe@machine.example>"},
+ "To": []string{"Mary Smith <mary@example.net>"},
+ "Subject": []string{"Saying Hello"},
+ "Date": []string{"Fri, 21 Nov 1997 09:55:06 -0600"},
+ "Message-Id": []string{"<1234@local.machine.example>"},
+ },
+ body: "This is a message just to say hello.\nSo, \"Hello\".\n",
+ },
+}
+
+func TestParsing(t *testing.T) {
+ for i, test := range parseTests {
+ msg, err := ReadMessage(bytes.NewBuffer([]byte(test.in)))
+ if err != nil {
+ t.Errorf("test #%d: Failed parsing message: %v", i, err)
+ continue
+ }
+ if !headerEq(msg.Header, test.header) {
+ t.Errorf("test #%d: Incorrectly parsed message header.\nGot:\n%+v\nWant:\n%+v",
+ i, msg.Header, test.header)
+ }
+ body, err := ioutil.ReadAll(msg.Body)
+ if err != nil {
+ t.Errorf("test #%d: Failed reading body: %v", i, err)
+ continue
+ }
+ bodyStr := string(body)
+ if bodyStr != test.body {
+ t.Errorf("test #%d: Incorrectly parsed message body.\nGot:\n%+v\nWant:\n%+v",
+ i, bodyStr, test.body)
+ }
+ }
+}
+
+func headerEq(a, b Header) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, as := range a {
+ bs, ok := b[k]
+ if !ok {
+ return false
+ }
+ if !reflect.DeepEqual(as, bs) {
+ return false
+ }
+ }
+ return true
+}
+
+func TestDateParsing(t *testing.T) {
+ tests := []struct {
+ dateStr string
+ exp time.Time
+ }{
+ // RFC 5322, Appendix A.1.1
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ },
+ // RFC5322, Appendix A.6.2
+ // Obsolete date.
+ {
+ "21 Nov 97 09:55:06 GMT",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("GMT", 0)),
+ },
+ }
+ for _, test := range tests {
+ hdr := Header{
+ "Date": []string{test.dateStr},
+ }
+ date, err := hdr.Date()
+ if err != nil {
+ t.Errorf("Failed parsing %q: %v", test.dateStr, err)
+ continue
+ }
+ if !date.Equal(test.exp) {
+ t.Errorf("Parse of %q: got %+v, want %+v", test.dateStr, date, test.exp)
+ }
+ }
+}
+
+func TestAddressParsing(t *testing.T) {
+ tests := []struct {
+ addrsStr string
+ exp []*Address
+ }{
+ // Bare address
+ {
+ `jdoe@machine.example`,
+ []*Address{{
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.1
+ {
+ `John Doe <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.2
+ {
+ `"Joe Q. Public" <john.q.public@example.com>`,
+ []*Address{{
+ Name: "Joe Q. Public",
+ Address: "john.q.public@example.com",
+ }},
+ },
+ {
+ `Mary Smith <mary@x.test>, jdoe@example.org, Who? <one@y.test>`,
+ []*Address{
+ {
+ Name: "Mary Smith",
+ Address: "mary@x.test",
+ },
+ {
+ Address: "jdoe@example.org",
+ },
+ {
+ Name: "Who?",
+ Address: "one@y.test",
+ },
+ },
+ },
+ {
+ `<boss@nil.test>, "Giant; \"Big\" Box" <sysservices@example.net>`,
+ []*Address{
+ {
+ Address: "boss@nil.test",
+ },
+ {
+ Name: `Giant; "Big" Box`,
+ Address: "sysservices@example.net",
+ },
+ },
+ },
+ // RFC 5322, Appendix A.1.3
+ // TODO(dsymonds): Group addresses.
+
+ // RFC 2047 "Q"-encoded ISO-8859-1 address.
+ {
+ `=?iso-8859-1?q?J=F6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded UTF-8 address.
+ {
+ `=?utf-8?q?J=C3=B6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047, Section 8.
+ {
+ `=?ISO-8859-1?Q?Andr=E9?= Pirard <PIRARD@vm1.ulg.ac.be>`,
+ []*Address{
+ {
+ Name: `André Pirard`,
+ Address: "PIRARD@vm1.ulg.ac.be",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded ISO-8859-1 address.
+ {
+ `=?ISO-8859-1?B?SvZyZw==?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded UTF-8 address.
+ {
+ `=?UTF-8?B?SsO2cmc=?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ addrs, err := newAddrParser(test.addrsStr).parseAddressList()
+ if err != nil {
+ t.Errorf("Failed parsing %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual(addrs, test.exp) {
+ t.Errorf("Parse of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp)
+ }
+ }
+}
+
+func TestAddressFormatting(t *testing.T) {
+ tests := []struct {
+ addr *Address
+ exp string
+ }{
+ {
+ &Address{Address: "bob@example.com"},
+ "<bob@example.com>",
+ },
+ {
+ &Address{Name: "Bob", Address: "bob@example.com"},
+ `"Bob" <bob@example.com>`,
+ },
+ {
+ // note the ö (o with an umlaut)
+ &Address{Name: "Böb", Address: "bob@example.com"},
+ `=?utf-8?q?B=C3=B6b?= <bob@example.com>`,
+ },
+ }
+ for _, test := range tests {
+ s := test.addr.String()
+ if s != test.exp {
+ t.Errorf("Address%+v.String() = %v, want %v", *test.addr, s, test.exp)
+ }
+ }
+}
diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go
index a66250c84..183d5a8ab 100644
--- a/src/pkg/net/multicast_test.go
+++ b/src/pkg/net/multicast_test.go
@@ -13,7 +13,7 @@ import (
var multicast = flag.Bool("multicast", false, "enable multicast tests")
-var joinAndLeaveGroupUDPTests = []struct {
+var multicastUDPTests = []struct {
net string
laddr IP
gaddr IP
@@ -32,8 +32,8 @@ var joinAndLeaveGroupUDPTests = []struct {
{"udp6", IPv6unspecified, ParseIP("ff0e::114"), (FlagUp | FlagLoopback), true},
}
-func TestJoinAndLeaveGroupUDP(t *testing.T) {
- if runtime.GOOS == "windows" {
+func TestMulticastUDP(t *testing.T) {
+ if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
return
}
if !*multicast {
@@ -41,7 +41,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
return
}
- for _, tt := range joinAndLeaveGroupUDPTests {
+ for _, tt := range multicastUDPTests {
var (
ifi *Interface
found bool
@@ -51,7 +51,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
}
ift, err := Interfaces()
if err != nil {
- t.Fatalf("Interfaces() failed: %v", err)
+ t.Fatalf("Interfaces failed: %v", err)
}
for _, x := range ift {
if x.Flags&tt.flags == tt.flags {
@@ -65,15 +65,20 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
}
c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr})
if err != nil {
- t.Fatal(err)
+ t.Fatalf("ListenUDP failed: %v", err)
}
defer c.Close()
if err := c.JoinGroup(ifi, tt.gaddr); err != nil {
- t.Fatal(err)
+ t.Fatalf("JoinGroup failed: %v", err)
+ }
+ if !tt.ipv6 {
+ testIPv4MulticastSocketOptions(t, c.fd, ifi)
+ } else {
+ testIPv6MulticastSocketOptions(t, c.fd, ifi)
}
ifmat, err := ifi.MulticastAddrs()
if err != nil {
- t.Fatalf("MulticastAddrs() failed: %v", err)
+ t.Fatalf("MulticastAddrs failed: %v", err)
}
for _, ifma := range ifmat {
if ifma.(*IPAddr).IP.Equal(tt.gaddr) {
@@ -85,7 +90,114 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
t.Fatalf("%q not found in RIB", tt.gaddr.String())
}
if err := c.LeaveGroup(ifi, tt.gaddr); err != nil {
- t.Fatal(err)
+ t.Fatalf("LeaveGroup failed: %v", err)
+ }
+ }
+}
+
+func TestSimpleMulticastUDP(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
+ if !*multicast {
+ t.Logf("test disabled; use --multicast to enable")
+ return
+ }
+
+ for _, tt := range multicastUDPTests {
+ var ifi *Interface
+ if tt.ipv6 {
+ continue
+ }
+ tt.flags = FlagUp | FlagMulticast
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ for _, x := range ift {
+ if x.Flags&tt.flags == tt.flags {
+ ifi = &x
+ break
+ }
+ }
+ if ifi == nil {
+ t.Logf("an appropriate multicast interface not found")
+ return
+ }
+ c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr})
+ if err != nil {
+ t.Fatalf("ListenUDP failed: %v", err)
+ }
+ defer c.Close()
+ if err := c.JoinGroup(ifi, tt.gaddr); err != nil {
+ t.Fatalf("JoinGroup failed: %v", err)
+ }
+ if err := c.LeaveGroup(ifi, tt.gaddr); err != nil {
+ t.Fatalf("LeaveGroup failed: %v", err)
}
}
}
+
+func testIPv4MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
+ ifmc, err := ipv4MulticastInterface(fd)
+ if err != nil {
+ t.Fatalf("ipv4MulticastInterface failed: %v", err)
+ }
+ t.Logf("IPv4 multicast interface: %v", ifmc)
+ err = setIPv4MulticastInterface(fd, ifi)
+ if err != nil {
+ t.Fatalf("setIPv4MulticastInterface failed: %v", err)
+ }
+
+ ttl, err := ipv4MulticastTTL(fd)
+ if err != nil {
+ t.Fatalf("ipv4MulticastTTL failed: %v", err)
+ }
+ t.Logf("IPv4 multicast TTL: %v", ttl)
+ err = setIPv4MulticastTTL(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv4MulticastTTL failed: %v", err)
+ }
+
+ loop, err := ipv4MulticastLoopback(fd)
+ if err != nil {
+ t.Fatalf("ipv4MulticastLoopback failed: %v", err)
+ }
+ t.Logf("IPv4 multicast loopback: %v", loop)
+ err = setIPv4MulticastLoopback(fd, false)
+ if err != nil {
+ t.Fatalf("setIPv4MulticastLoopback failed: %v", err)
+ }
+}
+
+func testIPv6MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
+ ifmc, err := ipv6MulticastInterface(fd)
+ if err != nil {
+ t.Fatalf("ipv6MulticastInterface failed: %v", err)
+ }
+ t.Logf("IPv6 multicast interface: %v", ifmc)
+ err = setIPv6MulticastInterface(fd, ifi)
+ if err != nil {
+ t.Fatalf("setIPv6MulticastInterface failed: %v", err)
+ }
+
+ hoplim, err := ipv6MulticastHopLimit(fd)
+ if err != nil {
+ t.Fatalf("ipv6MulticastHopLimit failed: %v", err)
+ }
+ t.Logf("IPv6 multicast hop limit: %v", hoplim)
+ err = setIPv6MulticastHopLimit(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv6MulticastHopLimit failed: %v", err)
+ }
+
+ loop, err := ipv6MulticastLoopback(fd)
+ if err != nil {
+ t.Fatalf("ipv6MulticastLoopback failed: %v", err)
+ }
+ t.Logf("IPv6 multicast loopback: %v", loop)
+ err = setIPv6MulticastLoopback(fd, false)
+ if err != nil {
+ t.Fatalf("setIPv6MulticastLoopback failed: %v", err)
+ }
+}
diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go
index 5c84d3434..609fee242 100644
--- a/src/pkg/net/net.go
+++ b/src/pkg/net/net.go
@@ -9,7 +9,10 @@ package net
// TODO(rsc):
// support for raw ethernet sockets
-import "os"
+import (
+ "errors"
+ "time"
+)
// Addr represents a network end point address.
type Addr interface {
@@ -21,16 +24,16 @@ type Addr interface {
type Conn interface {
// Read reads data from the connection.
// Read can be made to time out and return a net.Error with Timeout() == true
- // after a fixed time limit; see SetTimeout and SetReadTimeout.
- Read(b []byte) (n int, err os.Error)
+ // after a fixed time limit; see SetDeadline and SetReadDeadline.
+ Read(b []byte) (n int, err error)
// Write writes data to the connection.
// Write can be made to time out and return a net.Error with Timeout() == true
- // after a fixed time limit; see SetTimeout and SetWriteTimeout.
- Write(b []byte) (n int, err os.Error)
+ // after a fixed time limit; see SetDeadline and SetWriteDeadline.
+ Write(b []byte) (n int, err error)
// Close closes the connection.
- Close() os.Error
+ Close() error
// LocalAddr returns the local network address.
LocalAddr() Addr
@@ -38,26 +41,28 @@ type Conn interface {
// RemoteAddr returns the remote network address.
RemoteAddr() Addr
- // SetTimeout sets the read and write deadlines associated
+ // SetDeadline sets the read and write deadlines associated
// with the connection.
- SetTimeout(nsec int64) os.Error
-
- // SetReadTimeout sets the time (in nanoseconds) that
- // Read will wait for data before returning an error with Timeout() == true.
- // Setting nsec == 0 (the default) disables the deadline.
- SetReadTimeout(nsec int64) os.Error
-
- // SetWriteTimeout sets the time (in nanoseconds) that
- // Write will wait to send its data before returning an error with Timeout() == true.
- // Setting nsec == 0 (the default) disables the deadline.
+ SetDeadline(t time.Time) error
+
+ // SetReadDeadline sets the deadline for all Read calls to return.
+ // If the deadline is reached, Read will fail with a timeout
+ // (see type Error) instead of blocking.
+ // A zero value for t means Read will not time out.
+ SetReadDeadline(t time.Time) error
+
+ // SetWriteDeadline sets the deadline for all Write calls to return.
+ // If the deadline is reached, Write will fail with a timeout
+ // (see type Error) instead of blocking.
+ // A zero value for t means Write will not time out.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
- SetWriteTimeout(nsec int64) os.Error
+ SetWriteDeadline(t time.Time) error
}
// An Error represents a network error.
type Error interface {
- os.Error
+ error
Timeout() bool // Is the error a timeout?
Temporary() bool // Is the error temporary?
}
@@ -70,61 +75,63 @@ type PacketConn interface {
// was on the packet.
// ReadFrom can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
- // see SetTimeout and SetReadTimeout.
- ReadFrom(b []byte) (n int, addr Addr, err os.Error)
+ // see SetDeadline and SetReadDeadline.
+ ReadFrom(b []byte) (n int, addr Addr, err error)
// WriteTo writes a packet with payload b to addr.
// WriteTo can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
- // see SetTimeout and SetWriteTimeout.
+ // see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
- WriteTo(b []byte, addr Addr) (n int, err os.Error)
+ WriteTo(b []byte, addr Addr) (n int, err error)
// Close closes the connection.
- Close() os.Error
+ Close() error
// LocalAddr returns the local network address.
LocalAddr() Addr
- // SetTimeout sets the read and write deadlines associated
+ // SetDeadline sets the read and write deadlines associated
// with the connection.
- SetTimeout(nsec int64) os.Error
-
- // SetReadTimeout sets the time (in nanoseconds) that
- // Read will wait for data before returning an error with Timeout() == true.
- // Setting nsec == 0 (the default) disables the deadline.
- SetReadTimeout(nsec int64) os.Error
-
- // SetWriteTimeout sets the time (in nanoseconds) that
- // Write will wait to send its data before returning an error with Timeout() == true.
- // Setting nsec == 0 (the default) disables the deadline.
+ SetDeadline(t time.Time) error
+
+ // SetReadDeadline sets the deadline for all Read calls to return.
+ // If the deadline is reached, Read will fail with a timeout
+ // (see type Error) instead of blocking.
+ // A zero value for t means Read will not time out.
+ SetReadDeadline(t time.Time) error
+
+ // SetWriteDeadline sets the deadline for all Write calls to return.
+ // If the deadline is reached, Write will fail with a timeout
+ // (see type Error) instead of blocking.
+ // A zero value for t means Write will not time out.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
- SetWriteTimeout(nsec int64) os.Error
+ SetWriteDeadline(t time.Time) error
}
// A Listener is a generic network listener for stream-oriented protocols.
type Listener interface {
// Accept waits for and returns the next connection to the listener.
- Accept() (c Conn, err os.Error)
+ Accept() (c Conn, err error)
// Close closes the listener.
- Close() os.Error
+ Close() error
// Addr returns the listener's network address.
Addr() Addr
}
-var errMissingAddress = os.NewError("missing address")
+var errMissingAddress = errors.New("missing address")
type OpError struct {
- Op string
- Net string
- Addr Addr
- Error os.Error
+ Op string
+ Net string
+ Addr Addr
+ Err error
}
-func (e *OpError) String() string {
+func (e *OpError) Error() string {
if e == nil {
return "<nil>"
}
@@ -135,7 +142,7 @@ func (e *OpError) String() string {
if e.Addr != nil {
s += " " + e.Addr.String()
}
- s += ": " + e.Error.String()
+ s += ": " + e.Err.Error()
return s
}
@@ -144,7 +151,7 @@ type temporary interface {
}
func (e *OpError) Temporary() bool {
- t, ok := e.Error.(temporary)
+ t, ok := e.Err.(temporary)
return ok && t.Temporary()
}
@@ -153,20 +160,28 @@ type timeout interface {
}
func (e *OpError) Timeout() bool {
- t, ok := e.Error.(timeout)
+ t, ok := e.Err.(timeout)
return ok && t.Timeout()
}
+type timeoutError struct{}
+
+func (e *timeoutError) Error() string { return "i/o timeout" }
+func (e *timeoutError) Timeout() bool { return true }
+func (e *timeoutError) Temporary() bool { return true }
+
+var errTimeout error = &timeoutError{}
+
type AddrError struct {
- Error string
- Addr string
+ Err string
+ Addr string
}
-func (e *AddrError) String() string {
+func (e *AddrError) Error() string {
if e == nil {
return "<nil>"
}
- s := e.Error
+ s := e.Err
if e.Addr != "" {
s += " " + e.Addr
}
@@ -183,6 +198,6 @@ func (e *AddrError) Timeout() bool {
type UnknownNetworkError string
-func (e UnknownNetworkError) String() string { return "unknown network " + string(e) }
+func (e UnknownNetworkError) Error() string { return "unknown network " + string(e) }
func (e UnknownNetworkError) Temporary() bool { return false }
func (e UnknownNetworkError) Timeout() bool { return false }
diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go
index 698a84527..0dc86698e 100644
--- a/src/pkg/net/net_test.go
+++ b/src/pkg/net/net_test.go
@@ -6,7 +6,9 @@ package net
import (
"flag"
+ "io"
"regexp"
+ "runtime"
"testing"
)
@@ -61,6 +63,8 @@ var dialErrorTests = []DialErrorTest{
},
}
+var duplicateErrorPattern = `dial (.*) dial (.*)`
+
func TestDialError(t *testing.T) {
if !*runErrorTest {
t.Logf("test disabled; use --run_error_test to enable")
@@ -75,11 +79,15 @@ func TestDialError(t *testing.T) {
t.Errorf("#%d: nil error, want match for %#q", i, tt.Pattern)
continue
}
- s := e.String()
+ s := e.Error()
match, _ := regexp.MatchString(tt.Pattern, s)
if !match {
t.Errorf("#%d: %q, want match for %#q", i, s, tt.Pattern)
}
+ match, _ = regexp.MatchString(duplicateErrorPattern, s)
+ if match {
+ t.Errorf("#%d: %q, duplicate error return from Dial", i, s)
+ }
}
}
@@ -111,11 +119,57 @@ func TestReverseAddress(t *testing.T) {
if len(tt.ErrPrefix) == 0 && e != nil {
t.Errorf("#%d: expected <nil>, got %q (error)", i, e)
}
- if e != nil && e.(*DNSError).Error != tt.ErrPrefix {
- t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Error)
+ if e != nil && e.(*DNSError).Err != tt.ErrPrefix {
+ t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Err)
}
if a != tt.Reverse {
t.Errorf("#%d: expected %q, got %q (reverse address)", i, tt.Reverse, a)
}
}
}
+
+func TestShutdown(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
+ l, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = Listen("tcp6", "[::1]:0"); err != nil {
+ t.Fatalf("ListenTCP on :0: %v", err)
+ }
+ }
+
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ var buf [10]byte
+ n, err := c.Read(buf[:])
+ if n != 0 || err != io.EOF {
+ t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err)
+ }
+ c.Write([]byte("response"))
+ c.Close()
+ }()
+
+ c, err := Dial("tcp", l.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+
+ err = c.(*TCPConn).CloseWrite()
+ if err != nil {
+ t.Fatalf("CloseWrite: %v", err)
+ }
+ var buf [10]byte
+ n, err := c.Read(buf[:])
+ if err != nil {
+ t.Fatalf("client Read: %d, %v", n, err)
+ }
+ got := string(buf[:n])
+ if got != "response" {
+ t.Errorf("read = %q, want \"response\"", got)
+ }
+}
diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver.go
index 3c9a6da53..a410bb6ce 100644
--- a/src/pkg/net/newpollserver.go
+++ b/src/pkg/net/newpollserver.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
package net
@@ -11,18 +11,17 @@ import (
"syscall"
)
-func newPollServer() (s *pollServer, err os.Error) {
+func newPollServer() (s *pollServer, err error) {
s = new(pollServer)
s.cr = make(chan *netFD, 1)
s.cw = make(chan *netFD, 1)
if s.pr, s.pw, err = os.Pipe(); err != nil {
return nil, err
}
- var e int
- if e = syscall.SetNonblock(s.pr.Fd(), true); e != 0 {
+ if err = syscall.SetNonblock(s.pr.Fd(), true); err != nil {
goto Errno
}
- if e = syscall.SetNonblock(s.pw.Fd(), true); e != 0 {
+ if err = syscall.SetNonblock(s.pw.Fd(), true); err != nil {
goto Errno
}
if s.poll, err = newpollster(); err != nil {
@@ -37,7 +36,7 @@ func newPollServer() (s *pollServer, err os.Error) {
return s, nil
Errno:
- err = &os.PathError{"setnonblock", s.pr.Name(), os.Errno(e)}
+ err = &os.PathError{"setnonblock", s.pr.Name(), err}
Error:
s.pr.Close()
s.pw.Close()
diff --git a/src/pkg/net/parse.go b/src/pkg/net/parse.go
index 0d30a7ac6..4c4200a49 100644
--- a/src/pkg/net/parse.go
+++ b/src/pkg/net/parse.go
@@ -54,7 +54,7 @@ func (f *file) readLine() (s string, ok bool) {
if n >= 0 {
f.data = f.data[0 : ln+n]
}
- if err == os.EOF {
+ if err == io.EOF {
f.atEOF = true
}
}
@@ -62,7 +62,7 @@ func (f *file) readLine() (s string, ok bool) {
return
}
-func open(name string) (*file, os.Error) {
+func open(name string) (*file, error) {
fd, err := os.Open(name)
if err != nil {
return nil, err
diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go
index 8d51eba18..dfbaba4d9 100644
--- a/src/pkg/net/parse_test.go
+++ b/src/pkg/net/parse_test.go
@@ -7,8 +7,8 @@ package net
import (
"bufio"
"os"
- "testing"
"runtime"
+ "testing"
)
func TestReadLine(t *testing.T) {
diff --git a/src/pkg/net/pipe.go b/src/pkg/net/pipe.go
index c0bbd356b..f1a2eca4e 100644
--- a/src/pkg/net/pipe.go
+++ b/src/pkg/net/pipe.go
@@ -1,8 +1,13 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package net
import (
+ "errors"
"io"
- "os"
+ "time"
)
// Pipe creates a synchronous, in-memory, full duplex
@@ -32,7 +37,7 @@ func (pipeAddr) String() string {
return "pipe"
}
-func (p *pipe) Close() os.Error {
+func (p *pipe) Close() error {
err := p.PipeReader.Close()
err1 := p.PipeWriter.Close()
if err == nil {
@@ -49,14 +54,14 @@ func (p *pipe) RemoteAddr() Addr {
return pipeAddr(0)
}
-func (p *pipe) SetTimeout(nsec int64) os.Error {
- return os.NewError("net.Pipe does not support timeouts")
+func (p *pipe) SetDeadline(t time.Time) error {
+ return errors.New("net.Pipe does not support deadlines")
}
-func (p *pipe) SetReadTimeout(nsec int64) os.Error {
- return os.NewError("net.Pipe does not support timeouts")
+func (p *pipe) SetReadDeadline(t time.Time) error {
+ return errors.New("net.Pipe does not support deadlines")
}
-func (p *pipe) SetWriteTimeout(nsec int64) os.Error {
- return os.NewError("net.Pipe does not support timeouts")
+func (p *pipe) SetWriteDeadline(t time.Time) error {
+ return errors.New("net.Pipe does not support deadlines")
}
diff --git a/src/pkg/net/pipe_test.go b/src/pkg/net/pipe_test.go
index 7e4c6db44..afe4f2408 100644
--- a/src/pkg/net/pipe_test.go
+++ b/src/pkg/net/pipe_test.go
@@ -7,7 +7,6 @@ package net
import (
"bytes"
"io"
- "os"
"testing"
)
@@ -22,7 +21,7 @@ func checkWrite(t *testing.T, w io.Writer, data []byte, c chan int) {
c <- 0
}
-func checkRead(t *testing.T, r io.Reader, data []byte, wantErr os.Error) {
+func checkRead(t *testing.T, r io.Reader, data []byte, wantErr error) {
buf := make([]byte, len(data)+10)
n, err := r.Read(buf)
if err != wantErr {
@@ -52,6 +51,6 @@ func TestPipe(t *testing.T) {
checkRead(t, srv, []byte("a third line"), nil)
<-c
go srv.Close()
- checkRead(t, cli, nil, os.EOF)
+ checkRead(t, cli, nil, io.EOF)
cli.Close()
}
diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go
index a8ca60c60..16780da11 100644
--- a/src/pkg/net/port.go
+++ b/src/pkg/net/port.go
@@ -2,19 +2,16 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd
+// +build darwin freebsd linux netbsd openbsd
// Read system port mappings from /etc/services
package net
-import (
- "os"
- "sync"
-)
+import "sync"
var services map[string]map[string]int
-var servicesError os.Error
+var servicesError error
var onceReadServices sync.Once
func readServices() {
@@ -53,7 +50,7 @@ func readServices() {
}
// goLookupPort is the native Go implementation of LookupPort.
-func goLookupPort(network, service string) (port int, err os.Error) {
+func goLookupPort(network, service string) (port int, err error) {
onceReadServices.Do(readServices)
switch network {
diff --git a/src/pkg/net/rpc/Makefile b/src/pkg/net/rpc/Makefile
new file mode 100644
index 000000000..0e6c9846b
--- /dev/null
+++ b/src/pkg/net/rpc/Makefile
@@ -0,0 +1,13 @@
+# Copyright 2009 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=net/rpc
+GOFILES=\
+ client.go\
+ debug.go\
+ server.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go
new file mode 100644
index 000000000..abc1e59cd
--- /dev/null
+++ b/src/pkg/net/rpc/client.go
@@ -0,0 +1,288 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "bufio"
+ "encoding/gob"
+ "errors"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "sync"
+)
+
+// ServerError represents an error that has been returned from
+// the remote side of the RPC connection.
+type ServerError string
+
+func (e ServerError) Error() string {
+ return string(e)
+}
+
+var ErrShutdown = errors.New("connection is shut down")
+
+// Call represents an active RPC.
+type Call struct {
+ ServiceMethod string // The name of the service and method to call.
+ Args interface{} // The argument to the function (*struct).
+ Reply interface{} // The reply from the function (*struct).
+ Error error // After completion, the error status.
+ Done chan *Call // Strobes when call is complete; value is the error status.
+ seq uint64
+}
+
+// Client represents an RPC Client.
+// There may be multiple outstanding Calls associated
+// with a single Client.
+type Client struct {
+ mutex sync.Mutex // protects pending, seq, request
+ sending sync.Mutex
+ request Request
+ seq uint64
+ codec ClientCodec
+ pending map[uint64]*Call
+ closing bool
+ shutdown bool
+}
+
+// A ClientCodec implements writing of RPC requests and
+// reading of RPC responses for the client side of an RPC session.
+// The client calls WriteRequest to write a request to the connection
+// and calls ReadResponseHeader and ReadResponseBody in pairs
+// to read responses. The client calls Close when finished with the
+// connection. ReadResponseBody may be called with a nil
+// argument to force the body of the response to be read and then
+// discarded.
+type ClientCodec interface {
+ WriteRequest(*Request, interface{}) error
+ ReadResponseHeader(*Response) error
+ ReadResponseBody(interface{}) error
+
+ Close() error
+}
+
+func (client *Client) send(c *Call) {
+ // Register this call.
+ client.mutex.Lock()
+ if client.shutdown {
+ c.Error = ErrShutdown
+ client.mutex.Unlock()
+ c.done()
+ return
+ }
+ c.seq = client.seq
+ client.seq++
+ client.pending[c.seq] = c
+ client.mutex.Unlock()
+
+ // Encode and send the request.
+ client.sending.Lock()
+ defer client.sending.Unlock()
+ client.request.Seq = c.seq
+ client.request.ServiceMethod = c.ServiceMethod
+ if err := client.codec.WriteRequest(&client.request, c.Args); err != nil {
+ c.Error = err
+ c.done()
+ }
+}
+
+func (client *Client) input() {
+ var err error
+ var response Response
+ for err == nil {
+ response = Response{}
+ err = client.codec.ReadResponseHeader(&response)
+ if err != nil {
+ if err == io.EOF && !client.closing {
+ err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+ seq := response.Seq
+ client.mutex.Lock()
+ c := client.pending[seq]
+ delete(client.pending, seq)
+ client.mutex.Unlock()
+
+ if response.Error == "" {
+ err = client.codec.ReadResponseBody(c.Reply)
+ if err != nil {
+ c.Error = errors.New("reading body " + err.Error())
+ }
+ } else {
+ // We've got an error response. Give this to the request;
+ // any subsequent requests will get the ReadResponseBody
+ // error if there is one.
+ c.Error = ServerError(response.Error)
+ err = client.codec.ReadResponseBody(nil)
+ if err != nil {
+ err = errors.New("reading error body: " + err.Error())
+ }
+ }
+ c.done()
+ }
+ // Terminate pending calls.
+ client.mutex.Lock()
+ client.shutdown = true
+ for _, call := range client.pending {
+ call.Error = err
+ call.done()
+ }
+ client.mutex.Unlock()
+ if err != io.EOF || !client.closing {
+ log.Println("rpc: client protocol error:", err)
+ }
+}
+
+func (call *Call) done() {
+ select {
+ case call.Done <- call:
+ // ok
+ default:
+ // We don't want to block here. It is the caller's responsibility to make
+ // sure the channel has enough buffer space. See comment in Go().
+ log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
+ }
+}
+
+// NewClient returns a new Client to handle requests to the
+// set of services at the other end of the connection.
+// It adds a buffer to the write side of the connection so
+// the header and payload are sent as a unit.
+func NewClient(conn io.ReadWriteCloser) *Client {
+ encBuf := bufio.NewWriter(conn)
+ client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
+ return NewClientWithCodec(client)
+}
+
+// NewClientWithCodec is like NewClient but uses the specified
+// codec to encode requests and decode responses.
+func NewClientWithCodec(codec ClientCodec) *Client {
+ client := &Client{
+ codec: codec,
+ pending: make(map[uint64]*Call),
+ }
+ go client.input()
+ return client
+}
+
+type gobClientCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+}
+
+func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
+ if err = c.enc.Encode(r); err != nil {
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobClientCodec) Close() error {
+ return c.rwc.Close()
+}
+
+// DialHTTP connects to an HTTP RPC server at the specified network address
+// listening on the default HTTP RPC path.
+func DialHTTP(network, address string) (*Client, error) {
+ return DialHTTPPath(network, address, DefaultRPCPath)
+}
+
+// DialHTTPPath connects to an HTTP RPC server
+// at the specified network address and path.
+func DialHTTPPath(network, address, path string) (*Client, error) {
+ var err error
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
+
+ // Require successful HTTP response
+ // before switching to RPC protocol.
+ resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
+ if err == nil && resp.Status == connected {
+ return NewClient(conn), nil
+ }
+ if err == nil {
+ err = errors.New("unexpected HTTP response: " + resp.Status)
+ }
+ conn.Close()
+ return nil, &net.OpError{"dial-http", network + " " + address, nil, err}
+}
+
+// Dial connects to an RPC server at the specified network address.
+func Dial(network, address string) (*Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), nil
+}
+
+func (client *Client) Close() error {
+ client.mutex.Lock()
+ if client.shutdown || client.closing {
+ client.mutex.Unlock()
+ return ErrShutdown
+ }
+ client.closing = true
+ client.mutex.Unlock()
+ return client.codec.Close()
+}
+
+// Go invokes the function asynchronously. It returns the Call structure representing
+// the invocation. The done channel will signal when the call is complete by returning
+// the same Call object. If done is nil, Go will allocate a new channel.
+// If non-nil, done must be buffered or Go will deliberately crash.
+func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+ call := new(Call)
+ call.ServiceMethod = serviceMethod
+ call.Args = args
+ call.Reply = reply
+ if done == nil {
+ done = make(chan *Call, 10) // buffered.
+ } else {
+ // If caller passes done != nil, it must arrange that
+ // done has enough buffer for the number of simultaneous
+ // RPCs that will be using that channel. If the channel
+ // is totally unbuffered, it's best not to run at all.
+ if cap(done) == 0 {
+ log.Panic("rpc: done channel is unbuffered")
+ }
+ }
+ call.Done = done
+ if client.shutdown {
+ call.Error = ErrShutdown
+ call.done()
+ return call
+ }
+ client.send(call)
+ return call
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
+ if client.shutdown {
+ return ErrShutdown
+ }
+ call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
+ return call.Error
+}
diff --git a/src/pkg/net/rpc/debug.go b/src/pkg/net/rpc/debug.go
new file mode 100644
index 000000000..663663fe9
--- /dev/null
+++ b/src/pkg/net/rpc/debug.go
@@ -0,0 +1,90 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+/*
+ Some HTML presented at http://machine:port/debug/rpc
+ Lists services, their methods, and some statistics, still rudimentary.
+*/
+
+import (
+ "fmt"
+ "net/http"
+ "sort"
+ "text/template"
+)
+
+const debugText = `<html>
+ <body>
+ <title>Services</title>
+ {{range .}}
+ <hr>
+ Service {{.Name}}
+ <hr>
+ <table>
+ <th align=center>Method</th><th align=center>Calls</th>
+ {{range .Method}}
+ <tr>
+ <td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td>
+ <td align=center>{{.Type.NumCalls}}</td>
+ </tr>
+ {{end}}
+ </table>
+ {{end}}
+ </body>
+ </html>`
+
+var debug = template.Must(template.New("RPC debug").Parse(debugText))
+
+type debugMethod struct {
+ Type *methodType
+ Name string
+}
+
+type methodArray []debugMethod
+
+type debugService struct {
+ Service *service
+ Name string
+ Method methodArray
+}
+
+type serviceArray []debugService
+
+func (s serviceArray) Len() int { return len(s) }
+func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
+func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func (m methodArray) Len() int { return len(m) }
+func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
+func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+type debugHTTP struct {
+ *Server
+}
+
+// Runs at /debug/rpc
+func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Build a sorted version of the data.
+ var services = make(serviceArray, len(server.serviceMap))
+ i := 0
+ server.mu.Lock()
+ for sname, service := range server.serviceMap {
+ services[i] = debugService{service, sname, make(methodArray, len(service.method))}
+ j := 0
+ for mname, method := range service.method {
+ services[i].Method[j] = debugMethod{method, mname}
+ j++
+ }
+ sort.Sort(services[i].Method)
+ i++
+ }
+ server.mu.Unlock()
+ sort.Sort(services)
+ err := debug.Execute(w, services)
+ if err != nil {
+ fmt.Fprintln(w, "rpc: error executing template:", err.Error())
+ }
+}
diff --git a/src/pkg/net/rpc/jsonrpc/Makefile b/src/pkg/net/rpc/jsonrpc/Makefile
new file mode 100644
index 000000000..c5ea5373d
--- /dev/null
+++ b/src/pkg/net/rpc/jsonrpc/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2010 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../../Make.inc
+
+TARG=net/rpc/jsonrpc
+GOFILES=\
+ client.go\
+ server.go\
+
+include ../../../../Make.pkg
diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go
new file mode 100644
index 000000000..e6c7441f0
--- /dev/null
+++ b/src/pkg/net/rpc/jsonrpc/all_test.go
@@ -0,0 +1,221 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/rpc"
+ "testing"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+func (t *Arith) Add(args *Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args, reply *Reply) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) error {
+ panic("ERROR")
+}
+
+func init() {
+ rpc.Register(new(Arith))
+}
+
+func TestServer(t *testing.T) {
+ type addResp struct {
+ Id interface{} `json:"id"`
+ Result Reply `json:"result"`
+ Error interface{} `json:"error"`
+ }
+
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ // Send hand-coded requests to server, parse responses.
+ for i := 0; i < 10; i++ {
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
+ var resp addResp
+ err := dec.Decode(&resp)
+ if err != nil {
+ t.Fatalf("Decode: %s", err)
+ }
+ if resp.Error != nil {
+ t.Fatalf("resp.Error: %s", resp.Error)
+ }
+ if resp.Id.(string) != string(i) {
+ t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i))
+ }
+ if resp.Result.C != 2*i+1 {
+ t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
+ }
+ }
+
+ fmt.Fprintf(cli, "{}\n")
+ var resp addResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestClient(t *testing.T) {
+ // Assume server is okay (TestServer is above).
+ // Test client against server.
+ cli, srv := net.Pipe()
+ go ServeConn(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.Error() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+}
+
+func TestMalformedInput(t *testing.T) {
+ cli, srv := net.Pipe()
+ go cli.Write([]byte(`{id:1}`)) // invalid json
+ ServeConn(srv) // must return, not loop
+}
+
+func TestUnexpectedError(t *testing.T) {
+ cli, srv := myPipe()
+ go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
+ ServeConn(srv) // must return, not loop
+}
+
+// Copied from package net.
+func myPipe() (*pipe, *pipe) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ return &pipe{r1, w2}, &pipe{r2, w1}
+}
+
+type pipe struct {
+ *io.PipeReader
+ *io.PipeWriter
+}
+
+type pipeAddr int
+
+func (pipeAddr) Network() string {
+ return "pipe"
+}
+
+func (pipeAddr) String() string {
+ return "pipe"
+}
+
+func (p *pipe) Close() error {
+ err := p.PipeReader.Close()
+ err1 := p.PipeWriter.Close()
+ if err == nil {
+ err = err1
+ }
+ return err
+}
+
+func (p *pipe) LocalAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) RemoteAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) SetTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetReadTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetWriteTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
diff --git a/src/pkg/net/rpc/jsonrpc/client.go b/src/pkg/net/rpc/jsonrpc/client.go
new file mode 100644
index 000000000..3fa8cbf08
--- /dev/null
+++ b/src/pkg/net/rpc/jsonrpc/client.go
@@ -0,0 +1,123 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec
+// for the rpc package.
+package jsonrpc
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/rpc"
+ "sync"
+)
+
+type clientCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req clientRequest
+ resp clientResponse
+
+ // JSON-RPC responses include the request id but not the request method.
+ // Package rpc expects both.
+ // We save the request method in pending when sending a request
+ // and then look it up by request ID when filling out the rpc Response.
+ mutex sync.Mutex // protects pending
+ pending map[uint64]string // map request id to method name
+}
+
+// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
+func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
+ return &clientCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]string),
+ }
+}
+
+type clientRequest struct {
+ Method string `json:"method"`
+ Params [1]interface{} `json:"params"`
+ Id uint64 `json:"id"`
+}
+
+func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error {
+ c.mutex.Lock()
+ c.pending[r.Seq] = r.ServiceMethod
+ c.mutex.Unlock()
+ c.req.Method = r.ServiceMethod
+ c.req.Params[0] = param
+ c.req.Id = r.Seq
+ return c.enc.Encode(&c.req)
+}
+
+type clientResponse struct {
+ Id uint64 `json:"id"`
+ Result *json.RawMessage `json:"result"`
+ Error interface{} `json:"error"`
+}
+
+func (r *clientResponse) reset() {
+ r.Id = 0
+ r.Result = nil
+ r.Error = nil
+}
+
+func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
+ c.resp.reset()
+ if err := c.dec.Decode(&c.resp); err != nil {
+ return err
+ }
+
+ c.mutex.Lock()
+ r.ServiceMethod = c.pending[c.resp.Id]
+ delete(c.pending, c.resp.Id)
+ c.mutex.Unlock()
+
+ r.Error = ""
+ r.Seq = c.resp.Id
+ if c.resp.Error != nil {
+ x, ok := c.resp.Error.(string)
+ if !ok {
+ return fmt.Errorf("invalid error %v", c.resp.Error)
+ }
+ if x == "" {
+ x = "unspecified error"
+ }
+ r.Error = x
+ }
+ return nil
+}
+
+func (c *clientCodec) ReadResponseBody(x interface{}) error {
+ if x == nil {
+ return nil
+ }
+ return json.Unmarshal(*c.resp.Result, x)
+}
+
+func (c *clientCodec) Close() error {
+ return c.c.Close()
+}
+
+// NewClient returns a new rpc.Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *rpc.Client {
+ return rpc.NewClientWithCodec(NewClientCodec(conn))
+}
+
+// Dial connects to a JSON-RPC server at the specified network address.
+func Dial(network, address string) (*rpc.Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), err
+}
diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go
new file mode 100644
index 000000000..4c54553a7
--- /dev/null
+++ b/src/pkg/net/rpc/jsonrpc/server.go
@@ -0,0 +1,136 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "encoding/json"
+ "errors"
+ "io"
+ "net/rpc"
+ "sync"
+)
+
+type serverCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req serverRequest
+ resp serverResponse
+
+ // JSON-RPC clients can use arbitrary json values as request IDs.
+ // Package rpc expects uint64 request IDs.
+ // We assign uint64 sequence numbers to incoming requests
+ // but save the original request ID in the pending map.
+ // When rpc responds, we use the sequence number in
+ // the response to find the original request ID.
+ mutex sync.Mutex // protects seq, pending
+ seq uint64
+ pending map[uint64]*json.RawMessage
+}
+
+// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
+func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
+ return &serverCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]*json.RawMessage),
+ }
+}
+
+type serverRequest struct {
+ Method string `json:"method"`
+ Params *json.RawMessage `json:"params"`
+ Id *json.RawMessage `json:"id"`
+}
+
+func (r *serverRequest) reset() {
+ r.Method = ""
+ if r.Params != nil {
+ *r.Params = (*r.Params)[0:0]
+ }
+ if r.Id != nil {
+ *r.Id = (*r.Id)[0:0]
+ }
+}
+
+type serverResponse struct {
+ Id *json.RawMessage `json:"id"`
+ Result interface{} `json:"result"`
+ Error interface{} `json:"error"`
+}
+
+func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
+ c.req.reset()
+ if err := c.dec.Decode(&c.req); err != nil {
+ return err
+ }
+ r.ServiceMethod = c.req.Method
+
+ // JSON request id can be any JSON value;
+ // RPC package expects uint64. Translate to
+ // internal uint64 and save JSON on the side.
+ c.mutex.Lock()
+ c.seq++
+ c.pending[c.seq] = c.req.Id
+ c.req.Id = nil
+ r.Seq = c.seq
+ c.mutex.Unlock()
+
+ return nil
+}
+
+func (c *serverCodec) ReadRequestBody(x interface{}) error {
+ if x == nil {
+ return nil
+ }
+ // JSON params is array value.
+ // RPC params is struct.
+ // Unmarshal into array containing struct for now.
+ // Should think about making RPC more general.
+ var params [1]interface{}
+ params[0] = x
+ return json.Unmarshal(*c.req.Params, &params)
+}
+
+var null = json.RawMessage([]byte("null"))
+
+func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error {
+ var resp serverResponse
+ c.mutex.Lock()
+ b, ok := c.pending[r.Seq]
+ if !ok {
+ c.mutex.Unlock()
+ return errors.New("invalid sequence number in response")
+ }
+ delete(c.pending, r.Seq)
+ c.mutex.Unlock()
+
+ if b == nil {
+ // Invalid request so no id. Use JSON null.
+ b = &null
+ }
+ resp.Id = b
+ resp.Result = x
+ if r.Error == "" {
+ resp.Error = nil
+ } else {
+ resp.Error = r.Error
+ }
+ return c.enc.Encode(resp)
+}
+
+func (c *serverCodec) Close() error {
+ return c.c.Close()
+}
+
+// ServeConn runs the JSON-RPC server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+func ServeConn(conn io.ReadWriteCloser) {
+ rpc.ServeCodec(NewServerCodec(conn))
+}
diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go
new file mode 100644
index 000000000..920ae9137
--- /dev/null
+++ b/src/pkg/net/rpc/server.go
@@ -0,0 +1,640 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Package rpc provides access to the exported methods of an object across a
+ network or other I/O connection. A server registers an object, making it visible
+ as a service with the name of the type of the object. After registration, exported
+ methods of the object will be accessible remotely. A server may register multiple
+ objects (services) of different types but it is an error to register multiple
+ objects of the same type.
+
+ Only methods that satisfy these criteria will be made available for remote access;
+ other methods will be ignored:
+
+ - the method name is exported, that is, begins with an upper case letter.
+ - the method receiver is exported or local (defined in the package
+ registering the service).
+ - the method has two arguments, both exported or local types.
+ - the method's second argument is a pointer.
+ - the method has return type error.
+
+ The method's first argument represents the arguments provided by the caller; the
+ second argument represents the result parameters to be returned to the caller.
+ The method's return value, if non-nil, is passed back as a string that the client
+ sees as if created by errors.New.
+
+ The server may handle requests on a single connection by calling ServeConn. More
+ typically it will create a network listener and call Accept or, for an HTTP
+ listener, HandleHTTP and http.Serve.
+
+ A client wishing to use the service establishes a connection and then invokes
+ NewClient on the connection. The convenience function Dial (DialHTTP) performs
+ both steps for a raw network connection (an HTTP connection). The resulting
+ Client object has two methods, Call and Go, that specify the service and method to
+ call, a pointer containing the arguments, and a pointer to receive the result
+ parameters.
+
+ Call waits for the remote call to complete; Go launches the call asynchronously
+ and returns a channel that will signal completion.
+
+ Package "gob" is used to transport the data.
+
+ Here is a simple example. A server wishes to export an object of type Arith:
+
+ package server
+
+ type Args struct {
+ A, B int
+ }
+
+ type Quotient struct {
+ Quo, Rem int
+ }
+
+ type Arith int
+
+ func (t *Arith) Multiply(args *Args, reply *int) error {
+ *reply = args.A * args.B
+ return nil
+ }
+
+ func (t *Arith) Divide(args *Args, quo *Quotient) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ quo.Quo = args.A / args.B
+ quo.Rem = args.A % args.B
+ return nil
+ }
+
+ The server calls (for HTTP service):
+
+ arith := new(Arith)
+ rpc.Register(arith)
+ rpc.HandleHTTP()
+ l, e := net.Listen("tcp", ":1234")
+ if e != nil {
+ log.Fatal("listen error:", e)
+ }
+ go http.Serve(l, nil)
+
+ At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
+ "Arith.Divide". To invoke one, a client first dials the server:
+
+ client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
+ if err != nil {
+ log.Fatal("dialing:", err)
+ }
+
+ Then it can make a remote call:
+
+ // Synchronous call
+ args := &server.Args{7,8}
+ var reply int
+ err = client.Call("Arith.Multiply", args, &reply)
+ if err != nil {
+ log.Fatal("arith error:", err)
+ }
+ fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply)
+
+ or
+
+ // Asynchronous call
+ quotient := new(Quotient)
+ divCall := client.Go("Arith.Divide", args, &quotient, nil)
+ replyCall := <-divCall.Done // will be equal to divCall
+ // check errors, print, etc.
+
+ A server implementation will often provide a simple, type-safe wrapper for the
+ client.
+*/
+package rpc
+
+import (
+ "bufio"
+ "encoding/gob"
+ "errors"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "reflect"
+ "strings"
+ "sync"
+ "unicode"
+ "unicode/utf8"
+)
+
+const (
+ // Defaults used by HandleHTTP
+ DefaultRPCPath = "/_goRPC_"
+ DefaultDebugPath = "/debug/rpc"
+)
+
+// Precompute the reflect type for error. Can't use error directly
+// because Typeof takes an empty interface value. This is annoying.
+var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
+
+type methodType struct {
+ sync.Mutex // protects counters
+ method reflect.Method
+ ArgType reflect.Type
+ ReplyType reflect.Type
+ numCalls uint
+}
+
+type service struct {
+ name string // name of service
+ rcvr reflect.Value // receiver of methods for the service
+ typ reflect.Type // type of the receiver
+ method map[string]*methodType // registered methods
+}
+
+// Request is a header written before every RPC call. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Request struct {
+ ServiceMethod string // format: "Service.Method"
+ Seq uint64 // sequence number chosen by client
+ next *Request // for free list in Server
+}
+
+// Response is a header written before every RPC return. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Response struct {
+ ServiceMethod string // echoes that of the Request
+ Seq uint64 // echoes that of the request
+ Error string // error, if any.
+ next *Response // for free list in Server
+}
+
+// Server represents an RPC Server.
+type Server struct {
+ mu sync.Mutex // protects the serviceMap
+ serviceMap map[string]*service
+ reqLock sync.Mutex // protects freeReq
+ freeReq *Request
+ respLock sync.Mutex // protects freeResp
+ freeResp *Response
+}
+
+// NewServer returns a new Server.
+func NewServer() *Server {
+ return &Server{serviceMap: make(map[string]*service)}
+}
+
+// DefaultServer is the default instance of *Server.
+var DefaultServer = NewServer()
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// Is this type exported or a builtin?
+func isExportedOrBuiltinType(t reflect.Type) bool {
+ for t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ // PkgPath will be non-empty even for an exported type,
+ // so we need to check the type name as well.
+ return isExported(t.Name()) || t.PkgPath() == ""
+}
+
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method
+// - two arguments, both pointers to exported structs
+// - one return value, of type error
+// It returns an error if the receiver is not an exported type or has no
+// suitable methods.
+// The client accesses each method using a string of the form "Type.Method",
+// where Type is the receiver's concrete type.
+func (server *Server) Register(rcvr interface{}) error {
+ return server.register(rcvr, "", false)
+}
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func (server *Server) RegisterName(name string, rcvr interface{}) error {
+ return server.register(rcvr, name, true)
+}
+
+func (server *Server) register(rcvr interface{}, name string, useName bool) error {
+ server.mu.Lock()
+ defer server.mu.Unlock()
+ if server.serviceMap == nil {
+ server.serviceMap = make(map[string]*service)
+ }
+ s := new(service)
+ s.typ = reflect.TypeOf(rcvr)
+ s.rcvr = reflect.ValueOf(rcvr)
+ sname := reflect.Indirect(s.rcvr).Type().Name()
+ if useName {
+ sname = name
+ }
+ if sname == "" {
+ log.Fatal("rpc: no service name for type", s.typ.String())
+ }
+ if !isExported(sname) && !useName {
+ s := "rpc Register: type " + sname + " is not exported"
+ log.Print(s)
+ return errors.New(s)
+ }
+ if _, present := server.serviceMap[sname]; present {
+ return errors.New("rpc: service already defined: " + sname)
+ }
+ s.name = sname
+ s.method = make(map[string]*methodType)
+
+ // Install the methods
+ for m := 0; m < s.typ.NumMethod(); m++ {
+ method := s.typ.Method(m)
+ mtype := method.Type
+ mname := method.Name
+ if method.PkgPath != "" {
+ continue
+ }
+ // Method needs three ins: receiver, *args, *reply.
+ if mtype.NumIn() != 3 {
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ continue
+ }
+ // First arg need not be a pointer.
+ argType := mtype.In(1)
+ if !isExportedOrBuiltinType(argType) {
+ log.Println(mname, "argument type not exported or local:", argType)
+ continue
+ }
+ // Second arg must be a pointer.
+ replyType := mtype.In(2)
+ if replyType.Kind() != reflect.Ptr {
+ log.Println("method", mname, "reply type not a pointer:", replyType)
+ continue
+ }
+ if !isExportedOrBuiltinType(replyType) {
+ log.Println("method", mname, "reply type not exported or local:", replyType)
+ continue
+ }
+ // Method needs one out: error.
+ if mtype.NumOut() != 1 {
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ continue
+ }
+ if returnType := mtype.Out(0); returnType != typeOfError {
+ log.Println("method", mname, "returns", returnType.String(), "not error")
+ continue
+ }
+ s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ }
+
+ if len(s.method) == 0 {
+ s := "rpc Register: type " + sname + " has no exported methods of suitable type"
+ log.Print(s)
+ return errors.New(s)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+// A value sent as a placeholder for the response when the server receives an invalid request.
+type InvalidRequest struct{}
+
+var invalidRequest = InvalidRequest{}
+
+func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
+ resp := server.getResponse()
+ // Encode the response header
+ resp.ServiceMethod = req.ServiceMethod
+ if errmsg != "" {
+ resp.Error = errmsg
+ reply = invalidRequest
+ }
+ resp.Seq = req.Seq
+ sending.Lock()
+ err := codec.WriteResponse(resp, reply)
+ if err != nil {
+ log.Println("rpc: writing response:", err)
+ }
+ sending.Unlock()
+ server.freeResponse(resp)
+}
+
+func (m *methodType) NumCalls() (n uint) {
+ m.Lock()
+ n = m.numCalls
+ m.Unlock()
+ return n
+}
+
+func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+ mtype.Lock()
+ mtype.numCalls++
+ mtype.Unlock()
+ function := mtype.method.Func
+ // Invoke the method, providing a new value for the reply.
+ returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
+ // The return value for the method is an error.
+ errInter := returnValues[0].Interface()
+ errmsg := ""
+ if errInter != nil {
+ errmsg = errInter.(error).Error()
+ }
+ server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+ server.freeRequest(req)
+}
+
+type gobServerCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+}
+
+func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
+ if err = c.enc.Encode(r); err != nil {
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobServerCodec) Close() error {
+ return c.rwc.Close()
+}
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
+ buf := bufio.NewWriter(conn)
+ srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
+ server.ServeCodec(srv)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func (server *Server) ServeCodec(codec ServerCodec) {
+ sending := new(sync.Mutex)
+ for {
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
+ if err != nil {
+ if err != io.EOF {
+ log.Println("rpc:", err)
+ }
+ if !keepReading {
+ break
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.Error())
+ server.freeRequest(req)
+ }
+ continue
+ }
+ go service.call(server, sending, mtype, req, argv, replyv, codec)
+ }
+ codec.Close()
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func (server *Server) ServeRequest(codec ServerCodec) error {
+ sending := new(sync.Mutex)
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
+ if err != nil {
+ if !keepReading {
+ return err
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.Error())
+ server.freeRequest(req)
+ }
+ return err
+ }
+ service.call(server, sending, mtype, req, argv, replyv, codec)
+ return nil
+}
+
+func (server *Server) getRequest() *Request {
+ server.reqLock.Lock()
+ req := server.freeReq
+ if req == nil {
+ req = new(Request)
+ } else {
+ server.freeReq = req.next
+ *req = Request{}
+ }
+ server.reqLock.Unlock()
+ return req
+}
+
+func (server *Server) freeRequest(req *Request) {
+ server.reqLock.Lock()
+ req.next = server.freeReq
+ server.freeReq = req
+ server.reqLock.Unlock()
+}
+
+func (server *Server) getResponse() *Response {
+ server.respLock.Lock()
+ resp := server.freeResp
+ if resp == nil {
+ resp = new(Response)
+ } else {
+ server.freeResp = resp.next
+ *resp = Response{}
+ }
+ server.respLock.Unlock()
+ return resp
+}
+
+func (server *Server) freeResponse(resp *Response) {
+ server.respLock.Lock()
+ resp.next = server.freeResp
+ server.freeResp = resp
+ server.respLock.Unlock()
+}
+
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
+ service, mtype, req, keepReading, err = server.readRequestHeader(codec)
+ if err != nil {
+ if !keepReading {
+ return
+ }
+ // discard body
+ codec.ReadRequestBody(nil)
+ return
+ }
+
+ // Decode the argument value.
+ argIsValue := false // if true, need to indirect before calling.
+ if mtype.ArgType.Kind() == reflect.Ptr {
+ argv = reflect.New(mtype.ArgType.Elem())
+ } else {
+ argv = reflect.New(mtype.ArgType)
+ argIsValue = true
+ }
+ // argv guaranteed to be a pointer now.
+ if err = codec.ReadRequestBody(argv.Interface()); err != nil {
+ return
+ }
+ if argIsValue {
+ argv = argv.Elem()
+ }
+
+ replyv = reflect.New(mtype.ReplyType.Elem())
+ return
+}
+
+func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err error) {
+ // Grab the request header.
+ req = server.getRequest()
+ err = codec.ReadRequestHeader(req)
+ if err != nil {
+ req = nil
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ err = errors.New("rpc: server cannot decode request: " + err.Error())
+ return
+ }
+
+ // We read the header successfully. If we see an error now,
+ // we can still recover and move on to the next request.
+ keepReading = true
+
+ serviceMethod := strings.Split(req.ServiceMethod, ".")
+ if len(serviceMethod) != 2 {
+ err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
+ return
+ }
+ // Look up the request.
+ server.mu.Lock()
+ service = server.serviceMap[serviceMethod[0]]
+ server.mu.Unlock()
+ if service == nil {
+ err = errors.New("rpc: can't find service " + req.ServiceMethod)
+ return
+ }
+ mtype = service.method[serviceMethod[1]]
+ if mtype == nil {
+ err = errors.New("rpc: can't find method " + req.ServiceMethod)
+ }
+ return
+}
+
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection. Accept blocks; the caller typically
+// invokes it in a go statement.
+func (server *Server) Accept(lis net.Listener) {
+ for {
+ conn, err := lis.Accept()
+ if err != nil {
+ log.Fatal("rpc.Serve: accept:", err.Error()) // TODO(r): exit?
+ }
+ go server.ServeConn(conn)
+ }
+}
+
+// Register publishes the receiver's methods in the DefaultServer.
+func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func RegisterName(name string, rcvr interface{}) error {
+ return DefaultServer.RegisterName(name, rcvr)
+}
+
+// A ServerCodec implements reading of RPC requests and writing of
+// RPC responses for the server side of an RPC session.
+// The server calls ReadRequestHeader and ReadRequestBody in pairs
+// to read requests from the connection, and it calls WriteResponse to
+// write a response back. The server calls Close when finished with the
+// connection. ReadRequestBody may be called with a nil
+// argument to force the body of the request to be read and discarded.
+type ServerCodec interface {
+ ReadRequestHeader(*Request) error
+ ReadRequestBody(interface{}) error
+ WriteResponse(*Response, interface{}) error
+
+ Close() error
+}
+
+// ServeConn runs the DefaultServer on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func ServeConn(conn io.ReadWriteCloser) {
+ DefaultServer.ServeConn(conn)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func ServeCodec(codec ServerCodec) {
+ DefaultServer.ServeCodec(codec)
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func ServeRequest(codec ServerCodec) error {
+ return DefaultServer.ServeRequest(codec)
+}
+
+// Accept accepts connections on the listener and serves requests
+// to DefaultServer for each incoming connection.
+// Accept blocks; the caller typically invokes it in a go statement.
+func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
+
+// Can connect to RPC service using HTTP CONNECT to rpcPath.
+var connected = "200 Connected to Go RPC"
+
+// ServeHTTP implements an http.Handler that answers RPC requests.
+func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if req.Method != "CONNECT" {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ io.WriteString(w, "405 must CONNECT\n")
+ return
+ }
+ conn, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
+ return
+ }
+ io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
+ server.ServeConn(conn)
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
+// and a debugging handler on debugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func (server *Server) HandleHTTP(rpcPath, debugPath string) {
+ http.Handle(rpcPath, server)
+ http.Handle(debugPath, debugHTTP{server})
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
+// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func HandleHTTP() {
+ DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
+}
diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go
new file mode 100644
index 000000000..b05c63c05
--- /dev/null
+++ b/src/pkg/net/rpc/server_test.go
@@ -0,0 +1,596 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http/httptest"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+var (
+ newServer *Server
+ serverAddr, newServerAddr string
+ httpServerAddr string
+ once, newOnce, httpOnce sync.Once
+)
+
+const (
+ newHttpPath = "/foo"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+// Some of Arith's methods have value args, some have pointer args. That's deliberate.
+
+func (t *Arith) Add(args Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args Args, reply *Reply) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) String(args *Args, reply *string) error {
+ *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ return nil
+}
+
+func (t *Arith) Scan(args string, reply *Reply) (err error) {
+ _, err = fmt.Sscan(args, &reply.C)
+ return
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) error {
+ panic("ERROR")
+}
+
+func listenTCP() (net.Listener, string) {
+ l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
+ if e != nil {
+ log.Fatalf("net.Listen tcp :0: %v", e)
+ }
+ return l, l.Addr().String()
+}
+
+func startServer() {
+ Register(new(Arith))
+
+ var l net.Listener
+ l, serverAddr = listenTCP()
+ log.Println("Test RPC server listening on", serverAddr)
+ go Accept(l)
+
+ HandleHTTP()
+ httpOnce.Do(startHttpServer)
+}
+
+func startNewServer() {
+ newServer = NewServer()
+ newServer.Register(new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ log.Println("NewServer test RPC server listening on", newServerAddr)
+ go Accept(l)
+
+ newServer.HandleHTTP(newHttpPath, "/bar")
+ httpOnce.Do(startHttpServer)
+}
+
+func startHttpServer() {
+ server := httptest.NewServer(nil)
+ httpServerAddr = server.Listener.Addr().String()
+ log.Println("Test HTTP RPC server listening on", httpServerAddr)
+}
+
+func TestRPC(t *testing.T) {
+ once.Do(startServer)
+ testRPC(t, serverAddr)
+ newOnce.Do(startNewServer)
+ testRPC(t, newServerAddr)
+}
+
+func testRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ // Nonexistent method
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.BadOperation", args, reply)
+ // expect an error
+ if err == nil {
+ t.Error("BadOperation: expected error")
+ } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
+ t.Errorf("BadOperation: expected can't find method error; got %q", err)
+ }
+
+ // Unknown service
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Unknown", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if strings.Index(err.Error(), "method") < 0 {
+ t.Error("expected error about method; got", err)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.Error() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+
+ // Bad type.
+ reply = new(Reply)
+ err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+ if err == nil {
+ t.Error("expected error calling Arith.Add with wrong arg type")
+ } else if strings.Index(err.Error(), "type") < 0 {
+ t.Error("expected error about type; got", err)
+ }
+
+ // Non-struct argument
+ const Val = 12345
+ str := fmt.Sprint(Val)
+ reply = new(Reply)
+ err = client.Call("Arith.Scan", &str, reply)
+ if err != nil {
+ t.Errorf("Scan: expected no error but got string %q", err.Error())
+ } else if reply.C != Val {
+ t.Errorf("Scan: expected %d got %d", Val, reply.C)
+ }
+
+ // Non-struct reply
+ args = &Args{27, 35}
+ str = ""
+ err = client.Call("Arith.String", args, &str)
+ if err != nil {
+ t.Errorf("String: expected no error but got string %q", err.Error())
+ }
+ expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ if str != expect {
+ t.Errorf("String: expected %s got %s", expect, str)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+}
+
+func TestHTTP(t *testing.T) {
+ once.Do(startServer)
+ testHTTPRPC(t, "")
+ newOnce.Do(startNewServer)
+ testHTTPRPC(t, newHttpPath)
+}
+
+func testHTTPRPC(t *testing.T, path string) {
+ var client *Client
+ var err error
+ if path == "" {
+ client, err = DialHTTP("tcp", httpServerAddr)
+ } else {
+ client, err = DialHTTPPath("tcp", httpServerAddr, path)
+ }
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+}
+
+// CodecEmulator provides a client-like api and a ServerCodec interface.
+// Can be used to test ServeRequest.
+type CodecEmulator struct {
+ server *Server
+ serviceMethod string
+ args *Args
+ reply *Reply
+ err error
+}
+
+func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
+ codec.serviceMethod = serviceMethod
+ codec.args = args
+ codec.reply = reply
+ codec.err = nil
+ var serverError error
+ if codec.server == nil {
+ serverError = ServeRequest(codec)
+ } else {
+ serverError = codec.server.ServeRequest(codec)
+ }
+ if codec.err == nil && serverError != nil {
+ codec.err = serverError
+ }
+ return codec.err
+}
+
+func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
+ req.ServiceMethod = codec.serviceMethod
+ req.Seq = 0
+ return nil
+}
+
+func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
+ if codec.args == nil {
+ return io.ErrUnexpectedEOF
+ }
+ *(argv.(*Args)) = *codec.args
+ return nil
+}
+
+func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
+ if resp.Error != "" {
+ codec.err = errors.New(resp.Error)
+ } else {
+ *codec.reply = *(reply.(*Reply))
+ }
+ return nil
+}
+
+func (codec *CodecEmulator) Close() error {
+ return nil
+}
+
+func TestServeRequest(t *testing.T) {
+ once.Do(startServer)
+ testServeRequest(t, nil)
+ newOnce.Do(startNewServer)
+ testServeRequest(t, newServer)
+}
+
+func testServeRequest(t *testing.T, server *Server) {
+ client := CodecEmulator{server: server}
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ err = client.Call("Arith.Add", nil, reply)
+ if err == nil {
+ t.Errorf("expected error calling Arith.Add with nil arg")
+ }
+}
+
+type ReplyNotPointer int
+type ArgNotPublic int
+type ReplyNotPublic int
+type local struct{}
+
+func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
+ return nil
+}
+
+func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
+ return nil
+}
+
+func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
+ return nil
+}
+
+// Check that registration handles lots of bad methods and a type with no suitable methods.
+func TestRegistrationError(t *testing.T) {
+ err := Register(new(ReplyNotPointer))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPointer")
+ }
+ err = Register(new(ArgNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ArgNotPublic")
+ }
+ err = Register(new(ReplyNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPublic")
+ }
+}
+
+type WriteFailCodec int
+
+func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
+ // the panic caused by this error used to not unlock a lock.
+ return errors.New("fail")
+}
+
+func (WriteFailCodec) ReadResponseHeader(*Response) error {
+ time.Sleep(120 * time.Second)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) ReadResponseBody(interface{}) error {
+ time.Sleep(120 * time.Second)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) Close() error {
+ return nil
+}
+
+func TestSendDeadlock(t *testing.T) {
+ client := NewClientWithCodec(WriteFailCodec(0))
+
+ done := make(chan bool)
+ go func() {
+ testSendDeadlock(client)
+ testSendDeadlock(client)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-time.After(5 * time.Second):
+ t.Fatal("deadlock")
+ }
+}
+
+func testSendDeadlock(client *Client) {
+ defer func() {
+ recover()
+ }()
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client.Call("Arith.Add", args, reply)
+}
+
+func dialDirect() (*Client, error) {
+ return Dial("tcp", serverAddr)
+}
+
+func dialHTTP() (*Client, error) {
+ return DialHTTP("tcp", httpServerAddr)
+}
+
+func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ t.Fatal("error dialing", err)
+ }
+ args := &Args{7, 8}
+ reply := new(Reply)
+ runtime.UpdateMemStats()
+ mallocs := 0 - runtime.MemStats.Mallocs
+ const count = 100
+ for i := 0; i < count; i++ {
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+ }
+ runtime.UpdateMemStats()
+ mallocs += runtime.MemStats.Mallocs
+ return mallocs / count
+}
+
+func TestCountMallocs(t *testing.T) {
+ fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t))
+}
+
+func TestCountMallocsOverHTTP(t *testing.T) {
+ fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t))
+}
+
+type writeCrasher struct {
+ done chan bool
+}
+
+func (writeCrasher) Close() error {
+ return nil
+}
+
+func (w *writeCrasher) Read(p []byte) (int, error) {
+ <-w.done
+ return 0, io.EOF
+}
+
+func (writeCrasher) Write(p []byte) (int, error) {
+ return 0, errors.New("fake write failure")
+}
+
+func TestClientWriteError(t *testing.T) {
+ w := &writeCrasher{done: make(chan bool)}
+ c := NewClient(w)
+ res := false
+ err := c.Call("foo", 1, &res)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if err.Error() != "fake write failure" {
+ t.Error("unexpected value of error:", err)
+ }
+ w.done <- true
+}
+
+func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
+ b.StopTimer()
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ b.Fatal("error dialing:", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ procs := runtime.GOMAXPROCS(-1)
+ N := int32(b.N)
+ var wg sync.WaitGroup
+ wg.Add(procs)
+ b.StartTimer()
+
+ for p := 0; p < procs; p++ {
+ go func() {
+ reply := new(Reply)
+ for atomic.AddInt32(&N, -1) >= 0 {
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
+ const MaxConcurrentCalls = 100
+ b.StopTimer()
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ b.Fatal("error dialing:", err)
+ }
+
+ // Asynchronous calls
+ args := &Args{7, 8}
+ procs := 4 * runtime.GOMAXPROCS(-1)
+ send := int32(b.N)
+ recv := int32(b.N)
+ var wg sync.WaitGroup
+ wg.Add(procs)
+ gate := make(chan bool, MaxConcurrentCalls)
+ res := make(chan *Call, MaxConcurrentCalls)
+ b.StartTimer()
+
+ for p := 0; p < procs; p++ {
+ go func() {
+ for atomic.AddInt32(&send, -1) >= 0 {
+ gate <- true
+ reply := new(Reply)
+ client.Go("Arith.Add", args, reply, res)
+ }
+ }()
+ go func() {
+ for call := range res {
+ A := call.Args.(*Args).A
+ B := call.Args.(*Args).B
+ C := call.Reply.(*Reply).C
+ if A+B != C {
+ b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
+ }
+ <-gate
+ if atomic.AddInt32(&recv, -1) == 0 {
+ close(res)
+ }
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkEndToEnd(b *testing.B) {
+ benchmarkEndToEnd(dialDirect, b)
+}
+
+func BenchmarkEndToEndHTTP(b *testing.B) {
+ benchmarkEndToEnd(dialHTTP, b)
+}
+
+func BenchmarkEndToEndAsync(b *testing.B) {
+ benchmarkEndToEndAsync(dialDirect, b)
+}
+
+func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
+ benchmarkEndToEndAsync(dialHTTP, b)
+}
diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go
index 6a5a06c8c..e9ab06666 100644
--- a/src/pkg/net/sendfile_linux.go
+++ b/src/pkg/net/sendfile_linux.go
@@ -21,7 +21,7 @@ const maxSendfileSize int = 4 << 20
// non-EOF error.
//
// if handled == false, sendFile performed no work.
-func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) {
+func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1 << 62 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
@@ -40,15 +40,6 @@ func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool)
defer c.wio.Unlock()
c.incref()
defer c.decref()
- if c.wdeadline_delta > 0 {
- // This is a little odd that we're setting the timeout
- // for the entire file but Write has the same issue
- // (if one slurps the whole file into memory and
- // do one large Write). At least they're consistent.
- c.wdeadline = pollserver.Now() + c.wdeadline_delta
- } else {
- c.wdeadline = 0
- }
dst := c.sysfd
src := f.Fd()
@@ -62,18 +53,18 @@ func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool)
written += int64(n)
remain -= int64(n)
}
- if n == 0 && errno == 0 {
+ if n == 0 && errno == nil {
break
}
if errno == syscall.EAGAIN && c.wdeadline >= 0 {
pollserver.WaitWrite(c)
continue
}
- if errno != 0 {
+ if errno != nil {
// This includes syscall.ENOSYS (no kernel
// support) and syscall.EINVAL (fd types which
// don't implement sendfile together)
- err = &OpError{"sendfile", c.net, c.raddr, os.Errno(errno)}
+ err = &OpError{"sendfile", c.net, c.raddr, errno}
break
}
}
diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go
index c55be6c08..ff76ab9cf 100644
--- a/src/pkg/net/sendfile_stub.go
+++ b/src/pkg/net/sendfile_stub.go
@@ -2,15 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd openbsd
+// +build darwin freebsd netbsd openbsd
package net
-import (
- "io"
- "os"
-)
+import "io"
-func sendFile(c *netFD, r io.Reader) (n int64, err os.Error, handled bool) {
+func sendFile(c *netFD, r io.Reader) (n int64, err error, handled bool) {
return 0, nil, false
}
diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go
index d9c2f537a..ee7ff8b98 100644
--- a/src/pkg/net/sendfile_windows.go
+++ b/src/pkg/net/sendfile_windows.go
@@ -16,7 +16,7 @@ type sendfileOp struct {
n uint32
}
-func (o *sendfileOp) Submit() (errno int) {
+func (o *sendfileOp) Submit() (err error) {
return syscall.TransmitFile(o.fd.sysfd, o.src, o.n, 0, &o.o, nil, syscall.TF_WRITE_BEHIND)
}
@@ -33,7 +33,7 @@ func (o *sendfileOp) Name() string {
// if handled == false, sendFile performed no work.
//
// Note that sendfile for windows does not suppport >2GB file.
-func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) {
+func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var n int64 = 0 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go
index a2ff218e7..b0b546be3 100644
--- a/src/pkg/net/server_test.go
+++ b/src/pkg/net/server_test.go
@@ -8,10 +8,10 @@ import (
"flag"
"io"
"os"
+ "runtime"
"strings"
- "syscall"
"testing"
- "runtime"
+ "time"
)
// Do not test empty datagrams by default.
@@ -55,7 +55,7 @@ func runServe(t *testing.T, network, addr string, listening chan<- string, done
func connect(t *testing.T, network, addr string, isEmpty bool) {
var fd Conn
- var err os.Error
+ var err error
if network == "unixgram" {
fd, err = DialUnix(network, &UnixAddr{addr + ".local", network}, &UnixAddr{addr, network})
} else {
@@ -64,7 +64,7 @@ func connect(t *testing.T, network, addr string, isEmpty bool) {
if err != nil {
t.Fatalf("net.Dial(%q, %q) = _, %v", network, addr, err)
}
- fd.SetReadTimeout(1e9) // 1s
+ fd.SetReadDeadline(time.Now().Add(1 * time.Second))
var b []byte
if !isEmpty {
@@ -92,7 +92,7 @@ func connect(t *testing.T, network, addr string, isEmpty bool) {
}
func doTest(t *testing.T, network, listenaddr, dialaddr string) {
- t.Logf("Test %q %q %q\n", network, listenaddr, dialaddr)
+ t.Logf("Test %q %q %q", network, listenaddr, dialaddr)
switch listenaddr {
case "", "0.0.0.0", "[::]", "[::ffff:0.0.0.0]":
if testing.Short() || avoidMacFirewall {
@@ -115,7 +115,7 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) {
}
func TestTCPServer(t *testing.T) {
- if syscall.OS != "openbsd" {
+ if runtime.GOOS != "openbsd" {
doTest(t, "tcp", "", "127.0.0.1")
}
doTest(t, "tcp", "0.0.0.0", "127.0.0.1")
@@ -155,7 +155,7 @@ func TestUnixServer(t *testing.T) {
os.Remove("/tmp/gotest.net")
doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net")
os.Remove("/tmp/gotest.net")
- if syscall.OS == "linux" {
+ if runtime.GOOS == "linux" {
doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net")
os.Remove("/tmp/gotest.net")
// Test abstract unix domain socket, a Linux-ism
@@ -170,10 +170,10 @@ func runPacket(t *testing.T, network, addr string, listening chan<- string, done
t.Fatalf("net.ListenPacket(%q, %q) = _, %v", network, addr, err)
}
listening <- c.LocalAddr().String()
- c.SetReadTimeout(10e6) // 10ms
var buf [1000]byte
Run:
for {
+ c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
n, addr, err := c.ReadFrom(buf[0:])
if e, ok := err.(Error); ok && e.Timeout() {
select {
@@ -195,7 +195,7 @@ Run:
}
func doTestPacket(t *testing.T, network, listenaddr, dialaddr string, isEmpty bool) {
- t.Logf("TestPacket %s %s %s\n", network, listenaddr, dialaddr)
+ t.Logf("TestPacket %q %q %q", network, listenaddr, dialaddr)
listening := make(chan string)
done := make(chan int)
if network == "udp" {
@@ -237,7 +237,7 @@ func TestUnixDatagramServer(t *testing.T) {
doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty)
os.Remove("/tmp/gotest1.net")
os.Remove("/tmp/gotest1.net.local")
- if syscall.OS == "linux" {
+ if runtime.GOOS == "linux" {
// Test abstract unix domain socket, a Linux-ism
doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty)
}
diff --git a/src/pkg/net/smtp/Makefile b/src/pkg/net/smtp/Makefile
new file mode 100644
index 000000000..d9812d5cb
--- /dev/null
+++ b/src/pkg/net/smtp/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2010 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=net/smtp
+GOFILES=\
+ auth.go\
+ smtp.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/net/smtp/auth.go b/src/pkg/net/smtp/auth.go
new file mode 100644
index 000000000..d401e3c21
--- /dev/null
+++ b/src/pkg/net/smtp/auth.go
@@ -0,0 +1,98 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package smtp
+
+import (
+ "crypto/hmac"
+ "crypto/md5"
+ "errors"
+ "fmt"
+)
+
+// Auth is implemented by an SMTP authentication mechanism.
+type Auth interface {
+ // Start begins an authentication with a server.
+ // It returns the name of the authentication protocol
+ // and optionally data to include in the initial AUTH message
+ // sent to the server. It can return proto == "" to indicate
+ // that the authentication should be skipped.
+ // If it returns a non-nil error, the SMTP client aborts
+ // the authentication attempt and closes the connection.
+ Start(server *ServerInfo) (proto string, toServer []byte, err error)
+
+ // Next continues the authentication. The server has just sent
+ // the fromServer data. If more is true, the server expects a
+ // response, which Next should return as toServer; otherwise
+ // Next should return toServer == nil.
+ // If Next returns a non-nil error, the SMTP client aborts
+ // the authentication attempt and closes the connection.
+ Next(fromServer []byte, more bool) (toServer []byte, err error)
+}
+
+// ServerInfo records information about an SMTP server.
+type ServerInfo struct {
+ Name string // SMTP server name
+ TLS bool // using TLS, with valid certificate for Name
+ Auth []string // advertised authentication mechanisms
+}
+
+type plainAuth struct {
+ identity, username, password string
+ host string
+}
+
+// PlainAuth returns an Auth that implements the PLAIN authentication
+// mechanism as defined in RFC 4616.
+// The returned Auth uses the given username and password to authenticate
+// on TLS connections to host and act as identity. Usually identity will be
+// left blank to act as username.
+func PlainAuth(identity, username, password, host string) Auth {
+ return &plainAuth{identity, username, password, host}
+}
+
+func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) {
+ if !server.TLS {
+ return "", nil, errors.New("unencrypted connection")
+ }
+ if server.Name != a.host {
+ return "", nil, errors.New("wrong host name")
+ }
+ resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password)
+ return "PLAIN", resp, nil
+}
+
+func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ // We've already sent everything.
+ return nil, errors.New("unexpected server challenge")
+ }
+ return nil, nil
+}
+
+type cramMD5Auth struct {
+ username, secret string
+}
+
+// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication
+// mechanism as defined in RFC 2195.
+// The returned Auth uses the given username and secret to authenticate
+// to the server using the challenge-response mechanism.
+func CRAMMD5Auth(username, secret string) Auth {
+ return &cramMD5Auth{username, secret}
+}
+
+func (a *cramMD5Auth) Start(server *ServerInfo) (string, []byte, error) {
+ return "CRAM-MD5", nil, nil
+}
+
+func (a *cramMD5Auth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ d := hmac.New(md5.New, []byte(a.secret))
+ d.Write(fromServer)
+ s := make([]byte, 0, d.Size())
+ return []byte(fmt.Sprintf("%s %x", a.username, d.Sum(s))), nil
+ }
+ return nil, nil
+}
diff --git a/src/pkg/net/smtp/smtp.go b/src/pkg/net/smtp/smtp.go
new file mode 100644
index 000000000..8d935ffb7
--- /dev/null
+++ b/src/pkg/net/smtp/smtp.go
@@ -0,0 +1,293 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC 5321.
+// It also implements the following extensions:
+// 8BITMIME RFC 1652
+// AUTH RFC 2554
+// STARTTLS RFC 3207
+// Additional extensions may be handled by clients.
+package smtp
+
+import (
+ "crypto/tls"
+ "encoding/base64"
+ "io"
+ "net"
+ "net/textproto"
+ "strings"
+)
+
+// A Client represents a client connection to an SMTP server.
+type Client struct {
+ // Text is the textproto.Conn used by the Client. It is exported to allow for
+ // clients to add extensions.
+ Text *textproto.Conn
+ // keep a reference to the connection so it can be used to create a TLS
+ // connection later
+ conn net.Conn
+ // whether the Client is using TLS
+ tls bool
+ serverName string
+ // map of supported extensions
+ ext map[string]string
+ // supported auth mechanisms
+ auth []string
+}
+
+// Dial returns a new Client connected to an SMTP server at addr.
+func Dial(addr string) (*Client, error) {
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ host := addr[:strings.Index(addr, ":")]
+ return NewClient(conn, host)
+}
+
+// NewClient returns a new Client using an existing connection and host as a
+// server name to be used when authenticating.
+func NewClient(conn net.Conn, host string) (*Client, error) {
+ text := textproto.NewConn(conn)
+ _, msg, err := text.ReadResponse(220)
+ if err != nil {
+ text.Close()
+ return nil, err
+ }
+ c := &Client{Text: text, conn: conn, serverName: host}
+ if strings.Contains(msg, "ESMTP") {
+ err = c.ehlo()
+ } else {
+ err = c.helo()
+ }
+ return c, err
+}
+
+// cmd is a convenience function that sends a command and returns the response
+func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
+ id, err := c.Text.Cmd(format, args...)
+ if err != nil {
+ return 0, "", err
+ }
+ c.Text.StartResponse(id)
+ defer c.Text.EndResponse(id)
+ code, msg, err := c.Text.ReadResponse(expectCode)
+ return code, msg, err
+}
+
+// helo sends the HELO greeting to the server. It should be used only when the
+// server does not support ehlo.
+func (c *Client) helo() error {
+ c.ext = nil
+ _, _, err := c.cmd(250, "HELO localhost")
+ return err
+}
+
+// ehlo sends the EHLO (extended hello) greeting to the server. It
+// should be the preferred greeting for servers that support it.
+func (c *Client) ehlo() error {
+ _, msg, err := c.cmd(250, "EHLO localhost")
+ if err != nil {
+ return err
+ }
+ ext := make(map[string]string)
+ extList := strings.Split(msg, "\n")
+ if len(extList) > 1 {
+ extList = extList[1:]
+ for _, line := range extList {
+ args := strings.SplitN(line, " ", 2)
+ if len(args) > 1 {
+ ext[args[0]] = args[1]
+ } else {
+ ext[args[0]] = ""
+ }
+ }
+ }
+ if mechs, ok := ext["AUTH"]; ok {
+ c.auth = strings.Split(mechs, " ")
+ }
+ c.ext = ext
+ return err
+}
+
+// StartTLS sends the STARTTLS command and encrypts all further communication.
+// Only servers that advertise the STARTTLS extension support this function.
+func (c *Client) StartTLS(config *tls.Config) error {
+ _, _, err := c.cmd(220, "STARTTLS")
+ if err != nil {
+ return err
+ }
+ c.conn = tls.Client(c.conn, config)
+ c.Text = textproto.NewConn(c.conn)
+ c.tls = true
+ return c.ehlo()
+}
+
+// Verify checks the validity of an email address on the server.
+// If Verify returns nil, the address is valid. A non-nil return
+// does not necessarily indicate an invalid address. Many servers
+// will not verify addresses for security reasons.
+func (c *Client) Verify(addr string) error {
+ _, _, err := c.cmd(250, "VRFY %s", addr)
+ return err
+}
+
+// Auth authenticates a client using the provided authentication mechanism.
+// A failed authentication closes the connection.
+// Only servers that advertise the AUTH extension support this function.
+func (c *Client) Auth(a Auth) error {
+ encoding := base64.StdEncoding
+ mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth})
+ if err != nil {
+ c.Quit()
+ return err
+ }
+ resp64 := make([]byte, encoding.EncodedLen(len(resp)))
+ encoding.Encode(resp64, resp)
+ code, msg64, err := c.cmd(0, "AUTH %s %s", mech, resp64)
+ for err == nil {
+ var msg []byte
+ switch code {
+ case 334:
+ msg, err = encoding.DecodeString(msg64)
+ case 235:
+ // the last message isn't base64 because it isn't a challenge
+ msg = []byte(msg64)
+ default:
+ err = &textproto.Error{code, msg64}
+ }
+ resp, err = a.Next(msg, code == 334)
+ if err != nil {
+ // abort the AUTH
+ c.cmd(501, "*")
+ c.Quit()
+ break
+ }
+ if resp == nil {
+ break
+ }
+ resp64 = make([]byte, encoding.EncodedLen(len(resp)))
+ encoding.Encode(resp64, resp)
+ code, msg64, err = c.cmd(0, string(resp64))
+ }
+ return err
+}
+
+// Mail issues a MAIL command to the server using the provided email address.
+// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME
+// parameter.
+// This initiates a mail transaction and is followed by one or more Rcpt calls.
+func (c *Client) Mail(from string) error {
+ cmdStr := "MAIL FROM:<%s>"
+ if c.ext != nil {
+ if _, ok := c.ext["8BITMIME"]; ok {
+ cmdStr += " BODY=8BITMIME"
+ }
+ }
+ _, _, err := c.cmd(250, cmdStr, from)
+ return err
+}
+
+// Rcpt issues a RCPT command to the server using the provided email address.
+// A call to Rcpt must be preceded by a call to Mail and may be followed by
+// a Data call or another Rcpt call.
+func (c *Client) Rcpt(to string) error {
+ _, _, err := c.cmd(25, "RCPT TO:<%s>", to)
+ return err
+}
+
+type dataCloser struct {
+ c *Client
+ io.WriteCloser
+}
+
+func (d *dataCloser) Close() error {
+ d.WriteCloser.Close()
+ _, _, err := d.c.Text.ReadResponse(250)
+ return err
+}
+
+// Data issues a DATA command to the server and returns a writer that
+// can be used to write the data. The caller should close the writer
+// before calling any more methods on c.
+// A call to Data must be preceded by one or more calls to Rcpt.
+func (c *Client) Data() (io.WriteCloser, error) {
+ _, _, err := c.cmd(354, "DATA")
+ if err != nil {
+ return nil, err
+ }
+ return &dataCloser{c, c.Text.DotWriter()}, nil
+}
+
+// SendMail connects to the server at addr, switches to TLS if possible,
+// authenticates with mechanism a if possible, and then sends an email from
+// address from, to addresses to, with message msg.
+func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
+ c, err := Dial(addr)
+ if err != nil {
+ return err
+ }
+ if ok, _ := c.Extension("STARTTLS"); ok {
+ if err = c.StartTLS(nil); err != nil {
+ return err
+ }
+ }
+ if a != nil && c.ext != nil {
+ if _, ok := c.ext["AUTH"]; ok {
+ if err = c.Auth(a); err != nil {
+ return err
+ }
+ }
+ }
+ if err = c.Mail(from); err != nil {
+ return err
+ }
+ for _, addr := range to {
+ if err = c.Rcpt(addr); err != nil {
+ return err
+ }
+ }
+ w, err := c.Data()
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(msg)
+ if err != nil {
+ return err
+ }
+ err = w.Close()
+ if err != nil {
+ return err
+ }
+ return c.Quit()
+}
+
+// Extension reports whether an extension is support by the server.
+// The extension name is case-insensitive. If the extension is supported,
+// Extension also returns a string that contains any parameters the
+// server specifies for the extension.
+func (c *Client) Extension(ext string) (bool, string) {
+ if c.ext == nil {
+ return false, ""
+ }
+ ext = strings.ToUpper(ext)
+ param, ok := c.ext[ext]
+ return ok, param
+}
+
+// Reset sends the RSET command to the server, aborting the current mail
+// transaction.
+func (c *Client) Reset() error {
+ _, _, err := c.cmd(250, "RSET")
+ return err
+}
+
+// Quit sends the QUIT command and closes the connection to the server.
+func (c *Client) Quit() error {
+ _, _, err := c.cmd(221, "QUIT")
+ if err != nil {
+ return err
+ }
+ return c.Text.Close()
+}
diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go
new file mode 100644
index 000000000..ce8878205
--- /dev/null
+++ b/src/pkg/net/smtp/smtp_test.go
@@ -0,0 +1,182 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package smtp
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "net/textproto"
+ "strings"
+ "testing"
+)
+
+type authTest struct {
+ auth Auth
+ challenges []string
+ name string
+ responses []string
+}
+
+var authTests = []authTest{
+ {PlainAuth("", "user", "pass", "testserver"), []string{}, "PLAIN", []string{"\x00user\x00pass"}},
+ {PlainAuth("foo", "bar", "baz", "testserver"), []string{}, "PLAIN", []string{"foo\x00bar\x00baz"}},
+ {CRAMMD5Auth("user", "pass"), []string{"<123456.1322876914@testserver>"}, "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}},
+}
+
+func TestAuth(t *testing.T) {
+testLoop:
+ for i, test := range authTests {
+ name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil})
+ if name != test.name {
+ t.Errorf("#%d got name %s, expected %s", i, name, test.name)
+ }
+ if !bytes.Equal(resp, []byte(test.responses[0])) {
+ t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0])
+ }
+ if err != nil {
+ t.Errorf("#%d error: %s", i, err)
+ }
+ for j := range test.challenges {
+ challenge := []byte(test.challenges[j])
+ expected := []byte(test.responses[j+1])
+ resp, err := test.auth.Next(challenge, true)
+ if err != nil {
+ t.Errorf("#%d error: %s", i, err)
+ continue testLoop
+ }
+ if !bytes.Equal(resp, expected) {
+ t.Errorf("#%d got %s, expected %s", i, resp, expected)
+ continue testLoop
+ }
+ }
+ }
+}
+
+type faker struct {
+ io.ReadWriter
+}
+
+func (f faker) Close() error {
+ return nil
+}
+
+func TestBasic(t *testing.T) {
+ basicServer = strings.Join(strings.Split(basicServer, "\n"), "\r\n")
+ basicClient = strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+
+ var cmdbuf bytes.Buffer
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(basicServer)), bcmdbuf)
+ c := &Client{Text: textproto.NewConn(fake)}
+
+ if err := c.helo(); err != nil {
+ t.Fatalf("HELO failed: %s", err)
+ }
+ if err := c.ehlo(); err == nil {
+ t.Fatalf("Expected first EHLO to fail")
+ }
+ if err := c.ehlo(); err != nil {
+ t.Fatalf("Second EHLO failed: %s", err)
+ }
+
+ if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" {
+ t.Fatalf("Expected AUTH supported")
+ }
+ if ok, _ := c.Extension("DSN"); ok {
+ t.Fatalf("Shouldn't support DSN")
+ }
+
+ if err := c.Mail("user@gmail.com"); err == nil {
+ t.Fatalf("MAIL should require authentication")
+ }
+
+ if err := c.Verify("user1@gmail.com"); err == nil {
+ t.Fatalf("First VRFY: expected no verification")
+ }
+ if err := c.Verify("user2@gmail.com"); err != nil {
+ t.Fatalf("Second VRFY: expected verification, got %s", err)
+ }
+
+ // fake TLS so authentication won't complain
+ c.tls = true
+ c.serverName = "smtp.google.com"
+ if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")); err != nil {
+ t.Fatalf("AUTH failed: %s", err)
+ }
+
+ if err := c.Mail("user@gmail.com"); err != nil {
+ t.Fatalf("MAIL failed: %s", err)
+ }
+ if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil {
+ t.Fatalf("RCPT failed: %s", err)
+ }
+ msg := `From: user@gmail.com
+To: golang-nuts@googlegroups.com
+Subject: Hooray for Go
+
+Line 1
+.Leading dot line .
+Goodbye.`
+ w, err := c.Data()
+ if err != nil {
+ t.Fatalf("DATA failed: %s", err)
+ }
+ if _, err := w.Write([]byte(msg)); err != nil {
+ t.Fatalf("Data write failed: %s", err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatalf("Bad data response: %s", err)
+ }
+
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if basicClient != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, basicClient)
+ }
+}
+
+var basicServer = `250 mx.google.com at your service
+502 Unrecognized command.
+250-mx.google.com at your service
+250-SIZE 35651584
+250-AUTH LOGIN PLAIN
+250 8BITMIME
+530 Authentication required
+252 Send some mail, I'll try my best
+250 User is valid
+235 Accepted
+250 Sender OK
+250 Receiver OK
+354 Go ahead
+250 Data OK
+221 OK
+`
+
+var basicClient = `HELO localhost
+EHLO localhost
+EHLO localhost
+MAIL FROM:<user@gmail.com> BODY=8BITMIME
+VRFY user1@gmail.com
+VRFY user2@gmail.com
+AUTH PLAIN AHVzZXIAcGFzcw==
+MAIL FROM:<user@gmail.com> BODY=8BITMIME
+RCPT TO:<golang-nuts@googlegroups.com>
+DATA
+From: user@gmail.com
+To: golang-nuts@googlegroups.com
+Subject: Hooray for Go
+
+Line 1
+..Leading dot line .
+Goodbye.
+.
+QUIT
+`
diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go
index 366e050ff..867e328f1 100644
--- a/src/pkg/net/sock.go
+++ b/src/pkg/net/sock.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
// Sockets
@@ -10,51 +10,46 @@ package net
import (
"io"
- "os"
"reflect"
"syscall"
)
-// Boolean to int.
-func boolint(b bool) int {
- if b {
- return 1
- }
- return 0
-}
+var listenerBacklog = maxListenerBacklog()
// Generic socket creation.
-func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) {
+func socket(net string, f, t, p int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
// See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock()
- s, e := syscall.Socket(f, p, t)
- if e != 0 {
+ s, err := syscall.Socket(f, t, p)
+ if err != nil {
syscall.ForkLock.RUnlock()
- return nil, os.Errno(e)
+ return nil, err
}
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()
- setKernelSpecificSockopt(s, f)
+ setDefaultSockopts(s, f, t)
if la != nil {
- e = syscall.Bind(s, la)
- if e != 0 {
+ err = syscall.Bind(s, la)
+ if err != nil {
closesocket(s)
- return nil, os.Errno(e)
+ return nil, err
}
}
- if fd, err = newFD(s, f, p, net); err != nil {
+ if fd, err = newFD(s, f, t, net); err != nil {
closesocket(s)
return nil, err
}
if ra != nil {
if err = fd.connect(ra); err != nil {
+ closesocket(s)
fd.Close()
return nil, err
}
+ fd.isConnected = true
}
sa, _ := syscall.Getsockname(s)
@@ -66,93 +61,11 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal
return fd, nil
}
-func setsockoptInt(fd *netFD, level, opt int, value int) os.Error {
- return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, level, opt, value))
-}
-
-func setsockoptNsec(fd *netFD, level, opt int, nsec int64) os.Error {
- var tv = syscall.NsecToTimeval(nsec)
- return os.NewSyscallError("setsockopt", syscall.SetsockoptTimeval(fd.sysfd, level, opt, &tv))
-}
-
-func setReadBuffer(fd *netFD, bytes int) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
-}
-
-func setWriteBuffer(fd *netFD, bytes int) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
-}
-
-func setReadTimeout(fd *netFD, nsec int64) os.Error {
- fd.rdeadline_delta = nsec
- return nil
-}
-
-func setWriteTimeout(fd *netFD, nsec int64) os.Error {
- fd.wdeadline_delta = nsec
- return nil
-}
-
-func setTimeout(fd *netFD, nsec int64) os.Error {
- if e := setReadTimeout(fd, nsec); e != nil {
- return e
- }
- return setWriteTimeout(fd, nsec)
-}
-
-func setReuseAddr(fd *netFD, reuse bool) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))
-}
-
-func bindToDevice(fd *netFD, dev string) os.Error {
- // TODO(rsc): call setsockopt with null-terminated string pointer
- return os.EINVAL
-}
-
-func setDontRoute(fd *netFD, dontroute bool) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))
-}
-
-func setKeepAlive(fd *netFD, keepalive bool) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
-}
-
-func setNoDelay(fd *netFD, noDelay bool) os.Error {
- fd.incref()
- defer fd.decref()
- return setsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))
-}
-
-func setLinger(fd *netFD, sec int) os.Error {
- var l syscall.Linger
- if sec >= 0 {
- l.Onoff = 1
- l.Linger = int32(sec)
- } else {
- l.Onoff = 0
- l.Linger = 0
- }
- fd.incref()
- defer fd.decref()
- e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l)
- return os.NewSyscallError("setsockopt", e)
-}
-
type UnknownSocketError struct {
sa syscall.Sockaddr
}
-func (e *UnknownSocketError) String() string {
+func (e *UnknownSocketError) Error() string {
return "unknown socket address type " + reflect.TypeOf(e.sa).String()
}
@@ -162,7 +75,7 @@ type writerOnly struct {
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
// applicable.
-func genericReadFrom(w io.Writer, r io.Reader) (n int64, err os.Error) {
+func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
return io.Copy(writerOnly{w}, r)
}
diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go
index c59802fec..630a91ed9 100644
--- a/src/pkg/net/sock_bsd.go
+++ b/src/pkg/net/sock_bsd.go
@@ -1,33 +1,33 @@
-// Copyright 2011 The Go Authors. All rights reserved.
+// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd
+// +build darwin freebsd netbsd openbsd
// Sockets for BSD variants
package net
import (
+ "runtime"
"syscall"
)
-func setKernelSpecificSockopt(s, f int) {
- // Allow reuse of recently-used addresses.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-
- // Allow reuse of recently-used ports.
- // This option is supported only in descendants of 4.4BSD,
- // to make an effective multicast application and an application
- // that requires quick draw possible.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
-
- // Allow broadcast.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
- if f == syscall.AF_INET6 {
- // using ip, tcp, udp, etc.
- // allow both protocols even if the OS default is otherwise.
- syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+func maxListenerBacklog() int {
+ var (
+ n uint32
+ err error
+ )
+ switch runtime.GOOS {
+ case "darwin", "freebsd":
+ n, err = syscall.SysctlUint32("kern.ipc.somaxconn")
+ case "netbsd":
+ // NOTE: NetBSD has no somaxconn-like kernel state so far
+ case "openbsd":
+ n, err = syscall.SysctlUint32("kern.somaxconn")
+ }
+ if n == 0 || err != nil {
+ return syscall.SOMAXCONN
}
+ return int(n)
}
diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go
index ec31e803b..2cbc34f24 100644
--- a/src/pkg/net/sock_linux.go
+++ b/src/pkg/net/sock_linux.go
@@ -1,4 +1,4 @@
-// Copyright 2011 The Go Authors. All rights reserved.
+// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -6,20 +6,22 @@
package net
-import (
- "syscall"
-)
+import "syscall"
-func setKernelSpecificSockopt(s, f int) {
- // Allow reuse of recently-used addresses.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-
- // Allow broadcast.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
- if f == syscall.AF_INET6 {
- // using ip, tcp, udp, etc.
- // allow both protocols even if the OS default is otherwise.
- syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+func maxListenerBacklog() int {
+ fd, err := open("/proc/sys/net/core/somaxconn")
+ if err != nil {
+ return syscall.SOMAXCONN
+ }
+ defer fd.close()
+ l, ok := fd.readLine()
+ if !ok {
+ return syscall.SOMAXCONN
+ }
+ f := getFields(l)
+ n, _, ok := dtoi(f[0], 0)
+ if n == 0 || !ok {
+ return syscall.SOMAXCONN
}
+ return n
}
diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go
index c6dbd0465..2d803de1f 100644
--- a/src/pkg/net/sock_windows.go
+++ b/src/pkg/net/sock_windows.go
@@ -1,4 +1,4 @@
-// Copyright 2011 The Go Authors. All rights reserved.
+// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -6,20 +6,9 @@
package net
-import (
- "syscall"
-)
+import "syscall"
-func setKernelSpecificSockopt(s syscall.Handle, f int) {
- // Allow reuse of recently-used addresses and ports.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-
- // Allow broadcast.
- syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
- if f == syscall.AF_INET6 {
- // using ip, tcp, udp, etc.
- // allow both protocols even if the OS default is otherwise.
- syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
- }
+func maxListenerBacklog() int {
+ // TODO: Implement this
+ return syscall.SOMAXCONN
}
diff --git a/src/pkg/net/sockopt.go b/src/pkg/net/sockopt.go
new file mode 100644
index 000000000..3d0f8dd7a
--- /dev/null
+++ b/src/pkg/net/sockopt.go
@@ -0,0 +1,180 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+// Socket options
+
+package net
+
+import (
+ "bytes"
+ "os"
+ "syscall"
+ "time"
+)
+
+// Boolean to int.
+func boolint(b bool) int {
+ if b {
+ return 1
+ }
+ return 0
+}
+
+func ipv4AddrToInterface(ip IP) (*Interface, error) {
+ ift, err := Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifi := range ift {
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if ip.Equal(v.IP) {
+ return &ifi, nil
+ }
+ case *IPNet:
+ if ip.Equal(v.IP) {
+ return &ifi, nil
+ }
+ }
+ }
+ }
+ if ip.Equal(IPv4zero) {
+ return nil, nil
+ }
+ return nil, errNoSuchInterface
+}
+
+func interfaceToIPv4Addr(ifi *Interface) (IP, error) {
+ if ifi == nil {
+ return IPv4zero, nil
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if v.IP.To4() != nil {
+ return v.IP, nil
+ }
+ case *IPNet:
+ if v.IP.To4() != nil {
+ return v.IP, nil
+ }
+ }
+ }
+ return nil, errNoSuchInterface
+}
+
+func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error {
+ if ifi == nil {
+ return nil
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if a := v.IP.To4(); a != nil {
+ copy(mreq.Interface[:], a)
+ goto done
+ }
+ case *IPNet:
+ if a := v.IP.To4(); a != nil {
+ copy(mreq.Interface[:], a)
+ goto done
+ }
+ }
+ }
+done:
+ if bytes.Equal(mreq.Multiaddr[:], IPv4zero.To4()) {
+ return errNoSuchMulticastInterface
+ }
+ return nil
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes))
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes))
+}
+
+func setReadDeadline(fd *netFD, t time.Time) error {
+ if t.IsZero() {
+ fd.rdeadline = 0
+ } else {
+ fd.rdeadline = t.UnixNano()
+ }
+ return nil
+}
+
+func setWriteDeadline(fd *netFD, t time.Time) error {
+ if t.IsZero() {
+ fd.wdeadline = 0
+ } else {
+ fd.wdeadline = t.UnixNano()
+ }
+ return nil
+}
+
+func setDeadline(fd *netFD, t time.Time) error {
+ if e := setReadDeadline(fd, t); e != nil {
+ return e
+ }
+ return setWriteDeadline(fd, t)
+}
+
+func setReuseAddr(fd *netFD, reuse bool) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)))
+}
+
+func setDontRoute(fd *netFD, dontroute bool) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)))
+}
+
+func setKeepAlive(fd *netFD, keepalive bool) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)))
+}
+
+func setNoDelay(fd *netFD, noDelay bool) error {
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)))
+}
+
+func setLinger(fd *netFD, sec int) error {
+ var l syscall.Linger
+ if sec >= 0 {
+ l.Onoff = 1
+ l.Linger = int32(sec)
+ } else {
+ l.Onoff = 0
+ l.Linger = 0
+ }
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l))
+}
diff --git a/src/pkg/net/sockopt_bsd.go b/src/pkg/net/sockopt_bsd.go
new file mode 100644
index 000000000..2093e0812
--- /dev/null
+++ b/src/pkg/net/sockopt_bsd.go
@@ -0,0 +1,45 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+// Socket options for BSD variants
+
+package net
+
+import (
+ "syscall"
+)
+
+func setDefaultSockopts(s, f, t int) {
+ switch f {
+ case syscall.AF_INET6:
+ // Allow both IP versions even if the OS default is otherwise.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+ }
+
+ if f == syscall.AF_UNIX ||
+ (f == syscall.AF_INET || f == syscall.AF_INET6) && t == syscall.SOCK_STREAM {
+ // Allow reuse of recently-used addresses.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+
+ // Allow reuse of recently-used ports.
+ // This option is supported only in descendants of 4.4BSD,
+ // to make an effective multicast application and an application
+ // that requires quick draw possible.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
+ }
+
+ // Allow broadcast.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+ fd.incref()
+ defer fd.decref()
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+ syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
+}
diff --git a/src/pkg/net/sockopt_linux.go b/src/pkg/net/sockopt_linux.go
new file mode 100644
index 000000000..9dbb4e5dd
--- /dev/null
+++ b/src/pkg/net/sockopt_linux.go
@@ -0,0 +1,37 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Socket options for Linux
+
+package net
+
+import (
+ "syscall"
+)
+
+func setDefaultSockopts(s, f, t int) {
+ switch f {
+ case syscall.AF_INET6:
+ // Allow both IP versions even if the OS default is otherwise.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+ }
+
+ if f == syscall.AF_UNIX ||
+ (f == syscall.AF_INET || f == syscall.AF_INET6) && t == syscall.SOCK_STREAM {
+ // Allow reuse of recently-used addresses.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+ }
+
+ // Allow broadcast.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+ fd.incref()
+ defer fd.decref()
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+}
diff --git a/src/pkg/net/sockopt_windows.go b/src/pkg/net/sockopt_windows.go
new file mode 100644
index 000000000..a7b5606d8
--- /dev/null
+++ b/src/pkg/net/sockopt_windows.go
@@ -0,0 +1,38 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Socket options for Windows
+
+package net
+
+import (
+ "syscall"
+)
+
+func setDefaultSockopts(s syscall.Handle, f, t int) {
+ switch f {
+ case syscall.AF_INET6:
+ // Allow both IP versions even if the OS default is otherwise.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+ }
+
+ // Windows will reuse recently-used addresses by default.
+ // SO_REUSEADDR should not be used here, as it allows
+ // a socket to forcibly bind to a port in use by another socket.
+ // This could lead to a non-deterministic behavior, where
+ // connection requests over the port cannot be guaranteed
+ // to be handled by the correct socket.
+
+ // Allow broadcast.
+ syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+ fd.incref()
+ defer fd.decref()
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+}
diff --git a/src/pkg/net/sockoptip.go b/src/pkg/net/sockoptip.go
new file mode 100644
index 000000000..90b6f751e
--- /dev/null
+++ b/src/pkg/net/sockoptip.go
@@ -0,0 +1,187 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+// IP-level socket options
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4TOS(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv4TOS(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4TTL(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv4TTL(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+ if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+ return err
+ }
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq))
+}
+
+func leaveIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+ if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+ return err
+ }
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq))
+}
+
+func ipv6HopLimit(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv6HopLimit(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv6MulticastInterface(fd *netFD) (*Interface, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ if v == 0 {
+ return nil, nil
+ }
+ ifi, err := InterfaceByIndex(v)
+ if err != nil {
+ return nil, err
+ }
+ return ifi, nil
+}
+
+func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int
+ if ifi != nil {
+ v = ifi.Index
+ }
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv6MulticastHopLimit(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv6MulticastHopLimit(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv6MulticastLoopback(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv6MulticastLoopback(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPv6Mreq{}
+ copy(mreq.Multiaddr[:], ip)
+ if ifi != nil {
+ mreq.Interface = uint32(ifi.Index)
+ }
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq))
+}
+
+func leaveIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPv6Mreq{}
+ copy(mreq.Multiaddr[:], ip)
+ if ifi != nil {
+ mreq.Interface = uint32(ifi.Index)
+ }
+ fd.incref()
+ defer fd.decref()
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq))
+}
diff --git a/src/pkg/net/sockoptip_bsd.go b/src/pkg/net/sockoptip_bsd.go
new file mode 100644
index 000000000..5f7dff248
--- /dev/null
+++ b/src/pkg/net/sockoptip_bsd.go
@@ -0,0 +1,54 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+// IP-level socket options for BSD variants
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return int(v), nil
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, byte(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_darwin.go b/src/pkg/net/sockoptip_darwin.go
new file mode 100644
index 000000000..dedfd6f4c
--- /dev/null
+++ b/src/pkg/net/sockoptip_darwin.go
@@ -0,0 +1,78 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Darwin
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+ fd.incref()
+ defer fd.decref()
+ a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ ip, err := interfaceToIPv4Addr(ifi)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ var x [4]byte
+ copy(x[:], ip.To4())
+ fd.incref()
+ defer fd.decref()
+ err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_freebsd.go b/src/pkg/net/sockoptip_freebsd.go
new file mode 100644
index 000000000..55f7b1a60
--- /dev/null
+++ b/src/pkg/net/sockoptip_freebsd.go
@@ -0,0 +1,80 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for FreeBSD
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+ fd.incref()
+ defer fd.decref()
+ mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ if int(mreq.Ifindex) == 0 {
+ return nil, nil
+ }
+ return InterfaceByIndex(int(mreq.Ifindex))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int32
+ if ifi != nil {
+ v = int32(ifi.Index)
+ }
+ mreq := &syscall.IPMreqn{Ifindex: v}
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_linux.go b/src/pkg/net/sockoptip_linux.go
new file mode 100644
index 000000000..360f8dea6
--- /dev/null
+++ b/src/pkg/net/sockoptip_linux.go
@@ -0,0 +1,120 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Linux
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+ fd.incref()
+ defer fd.decref()
+ mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ if int(mreq.Ifindex) == 0 {
+ return nil, nil
+ }
+ return InterfaceByIndex(int(mreq.Ifindex))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int32
+ if ifi != nil {
+ v = int32(ifi.Index)
+ }
+ mreq := &syscall.IPMreqn{Ifindex: v}
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
+ if err != nil {
+ return -1, os.NewSyscallError("getsockopt", err)
+ }
+ return v, nil
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_openbsd.go b/src/pkg/net/sockoptip_openbsd.go
new file mode 100644
index 000000000..89b8e4592
--- /dev/null
+++ b/src/pkg/net/sockoptip_openbsd.go
@@ -0,0 +1,78 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for OpenBSD
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+ fd.incref()
+ defer fd.decref()
+ a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ ip, err := interfaceToIPv4Addr(ifi)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ var x [4]byte
+ copy(x[:], ip.To4())
+ fd.incref()
+ defer fd.decref()
+ err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+ fd.incref()
+ defer fd.decref()
+ v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+ if err != nil {
+ return false, os.NewSyscallError("getsockopt", err)
+ }
+ return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+ fd.incref()
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_windows.go b/src/pkg/net/sockoptip_windows.go
new file mode 100644
index 000000000..3320e76bd
--- /dev/null
+++ b/src/pkg/net/sockoptip_windows.go
@@ -0,0 +1,61 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Windows
+
+package net
+
+import (
+ "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+ // TODO: Implement this
+ return nil, syscall.EWINDOWS
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ // TODO: Implement this
+ return syscall.EWINDOWS
+}
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+ // TODO: Implement this
+ return -1, syscall.EWINDOWS
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+ // TODO: Implement this
+ return syscall.EWINDOWS
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+ // TODO: Implement this
+ return false, syscall.EWINDOWS
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ // TODO: Implement this
+ return syscall.EWINDOWS
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+ // TODO: Implement this
+ return false, syscall.EWINDOWS
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+ // TODO: Implement this
+ return syscall.EWINDOWS
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+ // TODO: Implement this
+ return 0, syscall.EWINDOWS
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+ // TODO: Implement this
+ return syscall.EWINDOWS
+}
diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go
index f5c0a2781..47fbf2919 100644
--- a/src/pkg/net/tcpsock.go
+++ b/src/pkg/net/tcpsock.go
@@ -6,10 +6,6 @@
package net
-import (
- "os"
-)
-
// TCPAddr represents the address of a TCP end point.
type TCPAddr struct {
IP IP
@@ -31,7 +27,7 @@ func (a *TCPAddr) String() string {
// numeric addresses on the network net, which must be "tcp",
// "tcp4" or "tcp6". A literal IPv6 host address must be
// enclosed in square brackets, as in "[::]:80".
-func ResolveTCPAddr(net, addr string) (*TCPAddr, os.Error) {
+func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
ip, port, err := hostPortToIP(net, addr)
if err != nil {
return nil, err
diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go
index f4f6e9fee..f2444a4d9 100644
--- a/src/pkg/net/tcpsock_plan9.go
+++ b/src/pkg/net/tcpsock_plan9.go
@@ -8,6 +8,7 @@ package net
import (
"os"
+ "time"
)
// TCPConn is an implementation of the Conn interface
@@ -16,17 +17,50 @@ type TCPConn struct {
plan9Conn
}
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *TCPConn) SetDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *TCPConn) SetReadDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *TCPConn) SetWriteDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// CloseRead shuts down the reading side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseRead() error {
+ if !c.ok() {
+ return os.EINVAL
+ }
+ return os.EPLAN9
+}
+
+// CloseWrite shuts down the writing side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseWrite() error {
+ if !c.ok() {
+ return os.EINVAL
+ }
+ return os.EPLAN9
+}
+
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
-func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) {
+func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) {
switch net {
case "tcp", "tcp4", "tcp6":
default:
return nil, UnknownNetworkError(net)
}
if raddr == nil {
- return nil, &OpError{"dial", "tcp", nil, errMissingAddress}
+ return nil, &OpError{"dial", net, nil, errMissingAddress}
}
c1, err := dialPlan9(net, laddr, raddr)
if err != nil {
@@ -46,14 +80,14 @@ type TCPListener struct {
// Net must be "tcp", "tcp4", or "tcp6".
// If laddr has a port of 0, it means to listen on some available port.
// The caller can use l.Addr() to retrieve the chosen address.
-func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
+func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) {
switch net {
case "tcp", "tcp4", "tcp6":
default:
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", "tcp", nil, errMissingAddress}
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
l1, err := listenPlan9(net, laddr)
if err != nil {
diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go
index 35d536c31..65ec49303 100644
--- a/src/pkg/net/tcpsock_posix.go
+++ b/src/pkg/net/tcpsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
// TCP sockets
@@ -12,6 +12,7 @@ import (
"io"
"os"
"syscall"
+ "time"
)
// BUG(rsc): On OpenBSD, listening on the "tcp" network does not listen for
@@ -39,7 +40,7 @@ func (a *TCPAddr) family() int {
return syscall.AF_INET6
}
-func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) {
+func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, a.Port)
}
@@ -67,7 +68,7 @@ func (c *TCPConn) ok() bool { return c != nil && c.fd != nil }
// Implementation of the Conn interface - see Conn for documentation.
// Read implements the net.Conn Read method.
-func (c *TCPConn) Read(b []byte) (n int, err os.Error) {
+func (c *TCPConn) Read(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -75,7 +76,7 @@ func (c *TCPConn) Read(b []byte) (n int, err os.Error) {
}
// ReadFrom implements the io.ReaderFrom ReadFrom method.
-func (c *TCPConn) ReadFrom(r io.Reader) (int64, os.Error) {
+func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
if n, err, handled := sendFile(c.fd, r); handled {
return n, err
}
@@ -83,7 +84,7 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, os.Error) {
}
// Write implements the net.Conn Write method.
-func (c *TCPConn) Write(b []byte) (n int, err os.Error) {
+func (c *TCPConn) Write(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -91,7 +92,7 @@ func (c *TCPConn) Write(b []byte) (n int, err os.Error) {
}
// Close closes the TCP connection.
-func (c *TCPConn) Close() os.Error {
+func (c *TCPConn) Close() error {
if !c.ok() {
return os.EINVAL
}
@@ -100,6 +101,24 @@ func (c *TCPConn) Close() os.Error {
return err
}
+// CloseRead shuts down the reading side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseRead() error {
+ if !c.ok() {
+ return os.EINVAL
+ }
+ return c.fd.CloseRead()
+}
+
+// CloseWrite shuts down the writing side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseWrite() error {
+ if !c.ok() {
+ return os.EINVAL
+ }
+ return c.fd.CloseWrite()
+}
+
// LocalAddr returns the local network address, a *TCPAddr.
func (c *TCPConn) LocalAddr() Addr {
if !c.ok() {
@@ -116,33 +135,33 @@ func (c *TCPConn) RemoteAddr() Addr {
return c.fd.raddr
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *TCPConn) SetTimeout(nsec int64) os.Error {
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *TCPConn) SetDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setTimeout(c.fd, nsec)
+ return setDeadline(c.fd, t)
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *TCPConn) SetReadTimeout(nsec int64) os.Error {
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *TCPConn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setReadTimeout(c.fd, nsec)
+ return setReadDeadline(c.fd, t)
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *TCPConn) SetWriteTimeout(nsec int64) os.Error {
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *TCPConn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setWriteTimeout(c.fd, nsec)
+ return setWriteDeadline(c.fd, t)
}
// SetReadBuffer sets the size of the operating system's
// receive buffer associated with the connection.
-func (c *TCPConn) SetReadBuffer(bytes int) os.Error {
+func (c *TCPConn) SetReadBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -151,7 +170,7 @@ func (c *TCPConn) SetReadBuffer(bytes int) os.Error {
// SetWriteBuffer sets the size of the operating system's
// transmit buffer associated with the connection.
-func (c *TCPConn) SetWriteBuffer(bytes int) os.Error {
+func (c *TCPConn) SetWriteBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -169,7 +188,7 @@ func (c *TCPConn) SetWriteBuffer(bytes int) os.Error {
//
// If sec > 0, Close blocks for at most sec seconds waiting for
// data to be sent and acknowledged.
-func (c *TCPConn) SetLinger(sec int) os.Error {
+func (c *TCPConn) SetLinger(sec int) error {
if !c.ok() {
return os.EINVAL
}
@@ -178,7 +197,7 @@ func (c *TCPConn) SetLinger(sec int) os.Error {
// SetKeepAlive sets whether the operating system should send
// keepalive messages on the connection.
-func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error {
+func (c *TCPConn) SetKeepAlive(keepalive bool) error {
if !c.ok() {
return os.EINVAL
}
@@ -189,7 +208,7 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error {
// packet transmission in hopes of sending fewer packets
// (Nagle's algorithm). The default is true (no delay), meaning
// that data is sent as soon as possible after a Write.
-func (c *TCPConn) SetNoDelay(noDelay bool) os.Error {
+func (c *TCPConn) SetNoDelay(noDelay bool) error {
if !c.ok() {
return os.EINVAL
}
@@ -199,14 +218,14 @@ func (c *TCPConn) SetNoDelay(noDelay bool) os.Error {
// File returns a copy of the underlying os.File, set to blocking mode.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
-func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() }
+func (c *TCPConn) File() (f *os.File, err error) { return c.fd.dup() }
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
-func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) {
+func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) {
if raddr == nil {
- return nil, &OpError{"dial", "tcp", nil, errMissingAddress}
+ return nil, &OpError{"dial", net, nil, errMissingAddress}
}
fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP)
if e != nil {
@@ -226,15 +245,15 @@ type TCPListener struct {
// Net must be "tcp", "tcp4", or "tcp6".
// If laddr has a port of 0, it means to listen on some available port.
// The caller can use l.Addr() to retrieve the chosen address.
-func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
+func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) {
fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP)
if err != nil {
return nil, err
}
- errno := syscall.Listen(fd.sysfd, listenBacklog())
- if errno != 0 {
+ err = syscall.Listen(fd.sysfd, listenerBacklog)
+ if err != nil {
closesocket(fd.sysfd)
- return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)}
+ return nil, &OpError{"listen", net, laddr, err}
}
l = new(TCPListener)
l.fd = fd
@@ -243,7 +262,7 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
// AcceptTCP accepts the next incoming call and returns the new connection
// and the remote address.
-func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) {
+func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) {
if l == nil || l.fd == nil || l.fd.sysfd < 0 {
return nil, os.EINVAL
}
@@ -256,7 +275,7 @@ func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) {
// Accept implements the Accept method in the Listener interface;
// it waits for the next call and returns a generic Conn.
-func (l *TCPListener) Accept() (c Conn, err os.Error) {
+func (l *TCPListener) Accept() (c Conn, err error) {
c1, err := l.AcceptTCP()
if err != nil {
return nil, err
@@ -266,7 +285,7 @@ func (l *TCPListener) Accept() (c Conn, err os.Error) {
// Close stops listening on the TCP address.
// Already Accepted connections are not closed.
-func (l *TCPListener) Close() os.Error {
+func (l *TCPListener) Close() error {
if l == nil || l.fd == nil {
return os.EINVAL
}
@@ -276,15 +295,16 @@ func (l *TCPListener) Close() os.Error {
// Addr returns the listener's network address, a *TCPAddr.
func (l *TCPListener) Addr() Addr { return l.fd.laddr }
-// SetTimeout sets the deadline associated with the listener
-func (l *TCPListener) SetTimeout(nsec int64) os.Error {
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *TCPListener) SetDeadline(t time.Time) error {
if l == nil || l.fd == nil {
return os.EINVAL
}
- return setTimeout(l.fd, nsec)
+ return setDeadline(l.fd, t)
}
// File returns a copy of the underlying os.File, set to blocking mode.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
-func (l *TCPListener) File() (f *os.File, err os.Error) { return l.fd.dup() }
+func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() }
diff --git a/src/pkg/net/textproto/header.go b/src/pkg/net/textproto/header.go
index 288deb2ce..7fb32f804 100644
--- a/src/pkg/net/textproto/header.go
+++ b/src/pkg/net/textproto/header.go
@@ -39,5 +39,5 @@ func (h MIMEHeader) Get(key string) string {
// Del deletes the values associated with key.
func (h MIMEHeader) Del(key string) {
- h[CanonicalMIMEHeaderKey(key)] = nil, false
+ delete(h, CanonicalMIMEHeaderKey(key))
}
diff --git a/src/pkg/net/textproto/pipeline.go b/src/pkg/net/textproto/pipeline.go
index 8c25884b3..ca50eddac 100644
--- a/src/pkg/net/textproto/pipeline.go
+++ b/src/pkg/net/textproto/pipeline.go
@@ -108,7 +108,7 @@ func (s *sequencer) End(id uint) {
}
c, ok := s.wait[id]
if ok {
- s.wait[id] = nil, false
+ delete(s.wait, id)
}
s.mu.Unlock()
if ok {
diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go
index a404f4758..862cd536c 100644
--- a/src/pkg/net/textproto/reader.go
+++ b/src/pkg/net/textproto/reader.go
@@ -9,7 +9,6 @@ import (
"bytes"
"io"
"io/ioutil"
- "os"
"strconv"
"strings"
)
@@ -23,6 +22,7 @@ import (
type Reader struct {
R *bufio.Reader
dot *dotReader
+ buf []byte // a re-usable buffer for readContinuedLineSlice
}
// NewReader returns a new Reader reading from r.
@@ -32,13 +32,13 @@ func NewReader(r *bufio.Reader) *Reader {
// ReadLine reads a single line from r,
// eliding the final \n or \r\n from the returned string.
-func (r *Reader) ReadLine() (string, os.Error) {
+func (r *Reader) ReadLine() (string, error) {
line, err := r.readLineSlice()
return string(line), err
}
// ReadLineBytes is like ReadLine but returns a []byte instead of a string.
-func (r *Reader) ReadLineBytes() ([]byte, os.Error) {
+func (r *Reader) ReadLineBytes() ([]byte, error) {
line, err := r.readLineSlice()
if line != nil {
buf := make([]byte, len(line))
@@ -48,10 +48,24 @@ func (r *Reader) ReadLineBytes() ([]byte, os.Error) {
return line, err
}
-func (r *Reader) readLineSlice() ([]byte, os.Error) {
+func (r *Reader) readLineSlice() ([]byte, error) {
r.closeDot()
- line, _, err := r.R.ReadLine()
- return line, err
+ var line []byte
+ for {
+ l, more, err := r.R.ReadLine()
+ if err != nil {
+ return nil, err
+ }
+ // Avoid the copy if the first call produced a full line.
+ if line == nil && !more {
+ return l, nil
+ }
+ line = append(line, l...)
+ if !more {
+ break
+ }
+ }
+ return line, nil
}
// ReadContinuedLine reads a possibly continued line from r,
@@ -73,7 +87,7 @@ func (r *Reader) readLineSlice() ([]byte, os.Error) {
//
// A line consisting of only white space is never continued.
//
-func (r *Reader) ReadContinuedLine() (string, os.Error) {
+func (r *Reader) ReadContinuedLine() (string, error) {
line, err := r.readContinuedLineSlice()
return string(line), err
}
@@ -94,7 +108,7 @@ func trim(s []byte) []byte {
// ReadContinuedLineBytes is like ReadContinuedLine but
// returns a []byte instead of a string.
-func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) {
+func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
line, err := r.readContinuedLineSlice()
if line != nil {
buf := make([]byte, len(line))
@@ -104,81 +118,51 @@ func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) {
return line, err
}
-func (r *Reader) readContinuedLineSlice() ([]byte, os.Error) {
+func (r *Reader) readContinuedLineSlice() ([]byte, error) {
// Read the first line.
line, err := r.readLineSlice()
if err != nil {
- return line, err
+ return nil, err
}
if len(line) == 0 { // blank line - no continuation
return line, nil
}
- line = trim(line)
- copied := false
- if r.R.Buffered() < 1 {
- // ReadByte will flush the buffer; make a copy of the slice.
- copied = true
- line = append([]byte(nil), line...)
- }
-
- // Look for a continuation line.
- c, err := r.R.ReadByte()
- if err != nil {
- // Delay err until we read the byte next time.
- return line, nil
- }
- if c != ' ' && c != '\t' {
- // Not a continuation.
- r.R.UnreadByte()
- return line, nil
- }
-
- if !copied {
- // The next readLineSlice will invalidate the previous one.
- line = append(make([]byte, 0, len(line)*2), line...)
- }
+ // ReadByte or the next readLineSlice will flush the read buffer;
+ // copy the slice into buf.
+ r.buf = append(r.buf[:0], trim(line)...)
// Read continuation lines.
- for {
- // Consume leading spaces; one already gone.
- for {
- c, err = r.R.ReadByte()
- if err != nil {
- break
- }
- if c != ' ' && c != '\t' {
- r.R.UnreadByte()
- break
- }
- }
- var cont []byte
- cont, err = r.readLineSlice()
- cont = trim(cont)
- line = append(line, ' ')
- line = append(line, cont...)
+ for r.skipSpace() > 0 {
+ line, err := r.readLineSlice()
if err != nil {
break
}
+ r.buf = append(r.buf, ' ')
+ r.buf = append(r.buf, line...)
+ }
+ return r.buf, nil
+}
- // Check for leading space on next line.
- if c, err = r.R.ReadByte(); err != nil {
+// skipSpace skips R over all spaces and returns the number of bytes skipped.
+func (r *Reader) skipSpace() int {
+ n := 0
+ for {
+ c, err := r.R.ReadByte()
+ if err != nil {
+ // Bufio will keep err until next read.
break
}
if c != ' ' && c != '\t' {
r.R.UnreadByte()
break
}
+ n++
}
-
- // Delay error until next call.
- if len(line) > 0 {
- err = nil
- }
- return line, err
+ return n
}
-func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err os.Error) {
+func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
line, err := r.ReadLine()
if err != nil {
return
@@ -186,7 +170,7 @@ func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message
return parseCodeLine(line, expectCode)
}
-func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err os.Error) {
+func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
err = ProtocolError("short response: " + line)
return
@@ -221,7 +205,7 @@ func parseCodeLine(line string, expectCode int) (code int, continued bool, messa
//
// An expectCode <= 0 disables the check of the status code.
//
-func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.Error) {
+func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
code, continued, message, err := r.readCodeLine(expectCode)
if err == nil && continued {
err = ProtocolError("unexpected multi-line response: " + message)
@@ -251,12 +235,12 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.
//
// An expectCode <= 0 disables the check of the status code.
//
-func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.Error) {
+func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
code, continued, message, err := r.readCodeLine(expectCode)
for err == nil && continued {
line, err := r.ReadLine()
if err != nil {
- return
+ return 0, "", err
}
var code2 int
@@ -286,7 +270,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.
//
// The decoded form returned by the Reader's Read method
// rewrites the "\r\n" line endings into the simpler "\n",
-// removes leading dot escapes if present, and stops with error os.EOF
+// removes leading dot escapes if present, and stops with error io.EOF
// after consuming (and discarding) the end-of-sequence line.
func (r *Reader) DotReader() io.Reader {
r.closeDot()
@@ -300,7 +284,7 @@ type dotReader struct {
}
// Read satisfies reads by decoding dot-encoded data read from d.r.
-func (d *dotReader) Read(b []byte) (n int, err os.Error) {
+func (d *dotReader) Read(b []byte) (n int, err error) {
// Run data through a simple state machine to
// elide leading dots, rewrite trailing \r\n into \n,
// and detect ending .\r\n line.
@@ -317,7 +301,7 @@ func (d *dotReader) Read(b []byte) (n int, err os.Error) {
var c byte
c, err = br.ReadByte()
if err != nil {
- if err == os.EOF {
+ if err == io.EOF {
err = io.ErrUnexpectedEOF
}
break
@@ -379,7 +363,7 @@ func (d *dotReader) Read(b []byte) (n int, err os.Error) {
n++
}
if err == nil && d.state == stateEOF {
- err = os.EOF
+ err = io.EOF
}
if err != nil && d.r.dot == d {
d.r.dot = nil
@@ -404,7 +388,7 @@ func (r *Reader) closeDot() {
// ReadDotBytes reads a dot-encoding and returns the decoded data.
//
// See the documentation for the DotReader method for details about dot-encoding.
-func (r *Reader) ReadDotBytes() ([]byte, os.Error) {
+func (r *Reader) ReadDotBytes() ([]byte, error) {
return ioutil.ReadAll(r.DotReader())
}
@@ -412,17 +396,17 @@ func (r *Reader) ReadDotBytes() ([]byte, os.Error) {
// containing the decoded lines, with the final \r\n or \n elided from each.
//
// See the documentation for the DotReader method for details about dot-encoding.
-func (r *Reader) ReadDotLines() ([]string, os.Error) {
+func (r *Reader) ReadDotLines() ([]string, error) {
// We could use ReadDotBytes and then Split it,
// but reading a line at a time avoids needing a
// large contiguous block of memory and is simpler.
var v []string
- var err os.Error
+ var err error
for {
var line string
line, err = r.ReadLine()
if err != nil {
- if err == os.EOF {
+ if err == io.EOF {
err = io.ErrUnexpectedEOF
}
break
@@ -460,7 +444,7 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) {
// "Long-Key": {"Even Longer Value"},
// }
//
-func (r *Reader) ReadMIMEHeader() (MIMEHeader, os.Error) {
+func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
m := make(MIMEHeader)
for {
kv, err := r.readContinuedLineSlice()
diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go
index 23ebc3f61..4d0369148 100644
--- a/src/pkg/net/textproto/reader_test.go
+++ b/src/pkg/net/textproto/reader_test.go
@@ -7,7 +7,6 @@ package textproto
import (
"bufio"
"io"
- "os"
"reflect"
"strings"
"testing"
@@ -49,7 +48,7 @@ func TestReadLine(t *testing.T) {
t.Fatalf("Line 2: %s, %v", s, err)
}
s, err = r.ReadLine()
- if s != "" || err != os.EOF {
+ if s != "" || err != io.EOF {
t.Fatalf("EOF: %s, %v", s, err)
}
}
@@ -69,7 +68,7 @@ func TestReadContinuedLine(t *testing.T) {
t.Fatalf("Line 3: %s, %v", s, err)
}
s, err = r.ReadContinuedLine()
- if s != "" || err != os.EOF {
+ if s != "" || err != io.EOF {
t.Fatalf("EOF: %s, %v", s, err)
}
}
@@ -92,7 +91,7 @@ func TestReadCodeLine(t *testing.T) {
t.Fatalf("Line 3: wrong error %v\n", err)
}
code, msg, err = r.ReadCodeLine(1)
- if code != 0 || msg != "" || err != os.EOF {
+ if code != 0 || msg != "" || err != io.EOF {
t.Fatalf("EOF: %d, %s, %v", code, msg, err)
}
}
@@ -139,6 +138,32 @@ func TestReadMIMEHeader(t *testing.T) {
}
}
+func TestReadMIMEHeaderSingle(t *testing.T) {
+ r := reader("Foo: bar\n\n")
+ m, err := r.ReadMIMEHeader()
+ want := MIMEHeader{"Foo": {"bar"}}
+ if !reflect.DeepEqual(m, want) || err != nil {
+ t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+ }
+}
+
+func TestLargeReadMIMEHeader(t *testing.T) {
+ data := make([]byte, 16*1024)
+ for i := 0; i < len(data); i++ {
+ data[i] = 'x'
+ }
+ sdata := string(data)
+ r := reader("Cookie: " + sdata + "\r\n\n")
+ m, err := r.ReadMIMEHeader()
+ if err != nil {
+ t.Fatalf("ReadMIMEHeader: %v", err)
+ }
+ cookie := m.Get("Cookie")
+ if cookie != sdata {
+ t.Fatalf("ReadMIMEHeader: %v bytes, want %v bytes", len(cookie), len(sdata))
+ }
+}
+
type readResponseTest struct {
in string
inCode int
@@ -187,7 +212,7 @@ func TestRFC959Lines(t *testing.T) {
t.Errorf("#%d: code=%d, want %d", i, code, tt.wantCode)
}
if msg != tt.wantMsg {
- t.Errorf("%#d: msg=%q, want %q", i, msg, tt.wantMsg)
+ t.Errorf("#%d: msg=%q, want %q", i, msg, tt.wantMsg)
}
}
}
diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go
index 9f19b5495..317ec72b0 100644
--- a/src/pkg/net/textproto/textproto.go
+++ b/src/pkg/net/textproto/textproto.go
@@ -27,7 +27,6 @@ import (
"fmt"
"io"
"net"
- "os"
)
// An Error represents a numeric error response from a server.
@@ -36,7 +35,7 @@ type Error struct {
Msg string
}
-func (e *Error) String() string {
+func (e *Error) Error() string {
return fmt.Sprintf("%03d %s", e.Code, e.Msg)
}
@@ -44,7 +43,7 @@ func (e *Error) String() string {
// as an invalid response or a hung-up connection.
type ProtocolError string
-func (p ProtocolError) String() string {
+func (p ProtocolError) Error() string {
return string(p)
}
@@ -70,13 +69,13 @@ func NewConn(conn io.ReadWriteCloser) *Conn {
}
// Close closes the connection.
-func (c *Conn) Close() os.Error {
+func (c *Conn) Close() error {
return c.conn.Close()
}
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
-func Dial(network, addr string) (*Conn, os.Error) {
+func Dial(network, addr string) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
return nil, err
@@ -109,7 +108,7 @@ func Dial(network, addr string) (*Conn, os.Error) {
// }
// return c.ReadCodeLine(250)
//
-func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err os.Error) {
+func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) {
id = c.Next()
c.StartRequest(id)
err = c.PrintfLine(format, args...)
diff --git a/src/pkg/net/textproto/writer.go b/src/pkg/net/textproto/writer.go
index 4e705f6c3..03e2fd658 100644
--- a/src/pkg/net/textproto/writer.go
+++ b/src/pkg/net/textproto/writer.go
@@ -8,7 +8,6 @@ import (
"bufio"
"fmt"
"io"
- "os"
)
// A Writer implements convenience methods for writing
@@ -27,7 +26,7 @@ var crnl = []byte{'\r', '\n'}
var dotcrnl = []byte{'.', '\r', '\n'}
// PrintfLine writes the formatted output followed by \r\n.
-func (w *Writer) PrintfLine(format string, args ...interface{}) os.Error {
+func (w *Writer) PrintfLine(format string, args ...interface{}) error {
w.closeDot()
fmt.Fprintf(w.W, format, args...)
w.W.Write(crnl)
@@ -64,7 +63,7 @@ const (
wstateData // writing data in middle of line
)
-func (d *dotWriter) Write(b []byte) (n int, err os.Error) {
+func (d *dotWriter) Write(b []byte) (n int, err error) {
bw := d.w.W
for n < len(b) {
c := b[n]
@@ -100,7 +99,7 @@ func (d *dotWriter) Write(b []byte) (n int, err os.Error) {
return
}
-func (d *dotWriter) Close() os.Error {
+func (d *dotWriter) Close() error {
if d.w.dot == d {
d.w.dot = nil
}
diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go
index 0dbab5846..bae37c86b 100644
--- a/src/pkg/net/timeout_test.go
+++ b/src/pkg/net/timeout_test.go
@@ -5,7 +5,8 @@
package net
import (
- "os"
+ "fmt"
+ "runtime"
"testing"
"time"
)
@@ -17,35 +18,56 @@ func testTimeout(t *testing.T, network, addr string, readFrom bool) {
return
}
defer fd.Close()
- t0 := time.Nanoseconds()
- fd.SetReadTimeout(1e8) // 100ms
- var b [100]byte
- var n int
- var err1 os.Error
- if readFrom {
- n, _, err1 = fd.(PacketConn).ReadFrom(b[0:])
- } else {
- n, err1 = fd.Read(b[0:])
- }
- t1 := time.Nanoseconds()
what := "Read"
if readFrom {
what = "ReadFrom"
}
- if n != 0 || err1 == nil || !err1.(Error).Timeout() {
- t.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1)
- }
- if t1-t0 < 0.5e8 || t1-t0 > 1.5e8 {
- t.Errorf("fd.%s on %s %s took %f seconds, expected 0.1", what, network, addr, float64(t1-t0)/1e9)
+
+ errc := make(chan error, 1)
+ go func() {
+ t0 := time.Now()
+ fd.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ var b [100]byte
+ var n int
+ var err1 error
+ if readFrom {
+ n, _, err1 = fd.(PacketConn).ReadFrom(b[0:])
+ } else {
+ n, err1 = fd.Read(b[0:])
+ }
+ t1 := time.Now()
+ if n != 0 || err1 == nil || !err1.(Error).Timeout() {
+ errc <- fmt.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1)
+ return
+ }
+ if dt := t1.Sub(t0); dt < 50*time.Millisecond || dt > 250*time.Millisecond {
+ errc <- fmt.Errorf("fd.%s on %s %s took %s, expected 0.1s", what, network, addr, dt)
+ return
+ }
+ errc <- nil
+ }()
+ select {
+ case err := <-errc:
+ if err != nil {
+ t.Error(err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Errorf("%s on %s %s took over 1 second, expected 0.1s", what, network, addr)
}
}
func TestTimeoutUDP(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
testTimeout(t, "udp", "127.0.0.1:53", false)
testTimeout(t, "udp", "127.0.0.1:53", true)
}
func TestTimeoutTCP(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
// set up a listener that won't talk back
listening := make(chan string)
done := make(chan int)
@@ -55,3 +77,30 @@ func TestTimeoutTCP(t *testing.T) {
testTimeout(t, "tcp", addr, false)
<-done
}
+
+func TestDeadlineReset(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ tl := ln.(*TCPListener)
+ tl.SetDeadline(time.Now().Add(1 * time.Minute))
+ tl.SetDeadline(time.Time{}) // reset it
+ errc := make(chan error, 1)
+ go func() {
+ _, err := ln.Accept()
+ errc <- err
+ }()
+ select {
+ case <-time.After(50 * time.Millisecond):
+ // Pass.
+ case err := <-errc:
+ // Accept should never return; we never
+ // connected to it.
+ t.Errorf("unexpected return from Accept; err=%v", err)
+ }
+}
diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go
new file mode 100644
index 000000000..6ba762b1f
--- /dev/null
+++ b/src/pkg/net/udp_test.go
@@ -0,0 +1,87 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "testing"
+)
+
+func TestWriteToUDP(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ return
+ }
+
+ l, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ defer l.Close()
+
+ testWriteToConn(t, l.LocalAddr().String())
+ testWriteToPacketConn(t, l.LocalAddr().String())
+}
+
+func testWriteToConn(t *testing.T, raddr string) {
+ c, err := Dial("udp", raddr)
+ if err != nil {
+ t.Fatalf("Dial failed: %v", err)
+ }
+ defer c.Close()
+
+ ra, err := ResolveUDPAddr("udp", raddr)
+ if err != nil {
+ t.Fatalf("ResolveUDPAddr failed: %v", err)
+ }
+
+ _, err = c.(*UDPConn).WriteToUDP([]byte("Connection-oriented mode socket"), ra)
+ if err == nil {
+ t.Fatal("WriteToUDP should be failed")
+ }
+ if err != nil && err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("WriteToUDP should be failed as ErrWriteToConnected: %v", err)
+ }
+
+ _, err = c.(*UDPConn).WriteTo([]byte("Connection-oriented mode socket"), ra)
+ if err == nil {
+ t.Fatal("WriteTo should be failed")
+ }
+ if err != nil && err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("WriteTo should be failed as ErrWriteToConnected: %v", err)
+ }
+
+ _, err = c.Write([]byte("Connection-oriented mode socket"))
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+}
+
+func testWriteToPacketConn(t *testing.T, raddr string) {
+ c, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ defer c.Close()
+
+ ra, err := ResolveUDPAddr("udp", raddr)
+ if err != nil {
+ t.Fatalf("ResolveUDPAddr failed: %v", err)
+ }
+
+ _, err = c.(*UDPConn).WriteToUDP([]byte("Connection-less mode socket"), ra)
+ if err != nil {
+ t.Fatalf("WriteToUDP failed: %v", err)
+ }
+
+ _, err = c.WriteTo([]byte("Connection-less mode socket"), ra)
+ if err != nil {
+ t.Fatalf("WriteTo failed: %v", err)
+ }
+
+ _, err = c.(*UDPConn).Write([]byte("Connection-less mode socket"))
+ if err == nil {
+ t.Fatal("Write should be failed")
+ }
+}
diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go
index 3dfa71675..b3520cf09 100644
--- a/src/pkg/net/udpsock.go
+++ b/src/pkg/net/udpsock.go
@@ -6,10 +6,6 @@
package net
-import (
- "os"
-)
-
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
IP IP
@@ -31,7 +27,7 @@ func (a *UDPAddr) String() string {
// numeric addresses on the network net, which must be "udp",
// "udp4" or "udp6". A literal IPv6 host address must be
// enclosed in square brackets, as in "[::]:80".
-func ResolveUDPAddr(net, addr string) (*UDPAddr, os.Error) {
+func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
ip, port, err := hostPortToIP(net, addr)
if err != nil {
return nil, err
diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go
index d5c6ccb90..573438f85 100644
--- a/src/pkg/net/udpsock_plan9.go
+++ b/src/pkg/net/udpsock_plan9.go
@@ -7,7 +7,9 @@
package net
import (
+ "errors"
"os"
+ "time"
)
// UDPConn is the implementation of the Conn and PacketConn
@@ -16,6 +18,21 @@ type UDPConn struct {
plan9Conn
}
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *UDPConn) SetDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *UDPConn) SetReadDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *UDPConn) SetWriteDeadline(t time.Time) error {
+ return os.EPLAN9
+}
+
// UDP-specific methods.
// ReadFromUDP reads a UDP packet from c, copying the payload into b.
@@ -23,8 +40,8 @@ type UDPConn struct {
// that was on the packet.
//
// ReadFromUDP can be made to time out and return an error with Timeout() == true
-// after a fixed time limit; see SetTimeout and SetReadTimeout.
-func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) {
+// after a fixed time limit; see SetDeadline and SetReadDeadline.
+func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -40,7 +57,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) {
return
}
if m < udpHeaderSize {
- return 0, nil, os.NewError("short read reading UDP header")
+ return 0, nil, errors.New("short read reading UDP header")
}
buf = buf[:m]
@@ -50,7 +67,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) {
}
// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
+func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -61,9 +78,9 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
//
// WriteToUDP can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetWriteTimeout.
+// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
-func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) {
+func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -87,13 +104,13 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) {
}
// WriteTo implements the net.PacketConn WriteTo method.
-func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
a, ok := addr.(*UDPAddr)
if !ok {
- return 0, &OpError{"writeto", "udp", addr, os.EINVAL}
+ return 0, &OpError{"write", c.dir, addr, os.EINVAL}
}
return c.WriteToUDP(b, a)
}
@@ -101,14 +118,14 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
// DialUDP connects to the remote address raddr on the network net,
// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used
// as the local address for the connection.
-func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) {
+func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) {
switch net {
case "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(net)
}
if raddr == nil {
- return nil, &OpError{"dial", "udp", nil, errMissingAddress}
+ return nil, &OpError{"dial", net, nil, errMissingAddress}
}
c1, err := dialPlan9(net, laddr, raddr)
if err != nil {
@@ -149,14 +166,14 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
// local address laddr. The returned connection c's ReadFrom
// and WriteTo methods can be used to receive and send UDP
// packets with per-packet addressing.
-func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) {
+func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err error) {
switch net {
case "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", "udp", nil, errMissingAddress}
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
l, err := listenPlan9(net, laddr)
if err != nil {
@@ -172,7 +189,7 @@ func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) {
// JoinGroup joins the IP multicast group named by addr on ifi,
// which specifies the interface to join. JoinGroup uses the
// default multicast interface if ifi is nil.
-func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error {
+func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) error {
if !c.ok() {
return os.EINVAL
}
@@ -180,7 +197,7 @@ func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error {
}
// LeaveGroup exits the IP multicast group named by addr on ifi.
-func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error {
+func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) error {
if !c.ok() {
return os.EINVAL
}
diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go
index 06298ee40..fa3d29adf 100644
--- a/src/pkg/net/udpsock_posix.go
+++ b/src/pkg/net/udpsock_posix.go
@@ -2,18 +2,21 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
// UDP sockets
package net
import (
- "bytes"
+ "errors"
"os"
"syscall"
+ "time"
)
+var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP")
+
func sockaddrToUDP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
@@ -34,7 +37,7 @@ func (a *UDPAddr) family() int {
return syscall.AF_INET6
}
-func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) {
+func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, a.Port)
}
@@ -58,7 +61,7 @@ func (c *UDPConn) ok() bool { return c != nil && c.fd != nil }
// Implementation of the Conn interface - see Conn for documentation.
// Read implements the net.Conn Read method.
-func (c *UDPConn) Read(b []byte) (n int, err os.Error) {
+func (c *UDPConn) Read(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -66,7 +69,7 @@ func (c *UDPConn) Read(b []byte) (n int, err os.Error) {
}
// Write implements the net.Conn Write method.
-func (c *UDPConn) Write(b []byte) (n int, err os.Error) {
+func (c *UDPConn) Write(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -74,7 +77,7 @@ func (c *UDPConn) Write(b []byte) (n int, err os.Error) {
}
// Close closes the UDP connection.
-func (c *UDPConn) Close() os.Error {
+func (c *UDPConn) Close() error {
if !c.ok() {
return os.EINVAL
}
@@ -99,33 +102,33 @@ func (c *UDPConn) RemoteAddr() Addr {
return c.fd.raddr
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *UDPConn) SetTimeout(nsec int64) os.Error {
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *UDPConn) SetDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setTimeout(c.fd, nsec)
+ return setDeadline(c.fd, t)
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *UDPConn) SetReadTimeout(nsec int64) os.Error {
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *UDPConn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setReadTimeout(c.fd, nsec)
+ return setReadDeadline(c.fd, t)
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *UDPConn) SetWriteTimeout(nsec int64) os.Error {
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *UDPConn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setWriteTimeout(c.fd, nsec)
+ return setWriteDeadline(c.fd, t)
}
// SetReadBuffer sets the size of the operating system's
// receive buffer associated with the connection.
-func (c *UDPConn) SetReadBuffer(bytes int) os.Error {
+func (c *UDPConn) SetReadBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -134,7 +137,7 @@ func (c *UDPConn) SetReadBuffer(bytes int) os.Error {
// SetWriteBuffer sets the size of the operating system's
// transmit buffer associated with the connection.
-func (c *UDPConn) SetWriteBuffer(bytes int) os.Error {
+func (c *UDPConn) SetWriteBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -148,8 +151,8 @@ func (c *UDPConn) SetWriteBuffer(bytes int) os.Error {
// that was on the packet.
//
// ReadFromUDP can be made to time out and return an error with Timeout() == true
-// after a fixed time limit; see SetTimeout and SetReadTimeout.
-func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) {
+// after a fixed time limit; see SetDeadline and SetReadDeadline.
+func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -164,7 +167,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) {
}
// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
+func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -176,27 +179,30 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
//
// WriteToUDP can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetWriteTimeout.
+// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
-func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) {
+func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) {
if !c.ok() {
return 0, os.EINVAL
}
- sa, err1 := addr.sockaddr(c.fd.family)
- if err1 != nil {
- return 0, &OpError{Op: "write", Net: "udp", Addr: addr, Error: err1}
+ if c.fd.isConnected {
+ return 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected}
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, &OpError{"write", c.fd.net, addr, err}
}
return c.fd.WriteTo(b, sa)
}
// WriteTo implements the net.PacketConn WriteTo method.
-func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, os.EINVAL
}
a, ok := addr.(*UDPAddr)
if !ok {
- return 0, &OpError{"writeto", "udp", addr, os.EINVAL}
+ return 0, &OpError{"write", c.fd.net, addr, os.EINVAL}
}
return c.WriteToUDP(b, a)
}
@@ -204,14 +210,14 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
// DialUDP connects to the remote address raddr on the network net,
// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used
// as the local address for the connection.
-func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) {
+func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) {
switch net {
case "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(net)
}
if raddr == nil {
- return nil, &OpError{"dial", "udp", nil, errMissingAddress}
+ return nil, &OpError{"dial", net, nil, errMissingAddress}
}
fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP)
if e != nil {
@@ -224,44 +230,35 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) {
// local address laddr. The returned connection c's ReadFrom
// and WriteTo methods can be used to receive and send UDP
// packets with per-packet addressing.
-func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) {
+func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
switch net {
case "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", "udp", nil, errMissingAddress}
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
- fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP)
- if e != nil {
- return nil, e
+ fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
+ if err != nil {
+ return nil, err
}
return newUDPConn(fd), nil
}
-// BindToDevice binds a UDPConn to a network interface.
-func (c *UDPConn) BindToDevice(device string) os.Error {
- if !c.ok() {
- return os.EINVAL
- }
- c.fd.incref()
- defer c.fd.decref()
- return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device))
-}
-
// File returns a copy of the underlying os.File, set to blocking mode.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
-func (c *UDPConn) File() (f *os.File, err os.Error) { return c.fd.dup() }
+func (c *UDPConn) File() (f *os.File, err error) { return c.fd.dup() }
// JoinGroup joins the IP multicast group named by addr on ifi,
// which specifies the interface to join. JoinGroup uses the
// default multicast interface if ifi is nil.
-func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error {
+func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) error {
if !c.ok() {
return os.EINVAL
}
+ setDefaultMulticastSockopts(c.fd)
ip := addr.To4()
if ip != nil {
return joinIPv4GroupUDP(c, ifi, ip)
@@ -270,7 +267,7 @@ func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error {
}
// LeaveGroup exits the IP multicast group named by addr on ifi.
-func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error {
+func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) error {
if !c.ok() {
return os.EINVAL
}
@@ -281,68 +278,34 @@ func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error {
return leaveIPv6GroupUDP(c, ifi, addr)
}
-func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error {
- mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
- if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil {
- return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err}
- }
- if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)); err != nil {
- return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err}
- }
- return nil
-}
-
-func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error {
- mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
- if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil {
- return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err}
- }
- if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)); err != nil {
- return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err}
+func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ err := joinIPv4Group(c.fd, ifi, ip)
+ if err != nil {
+ return &OpError{"joinipv4group", c.fd.net, &IPAddr{ip}, err}
}
return nil
}
-func setIPv4InterfaceToJoin(mreq *syscall.IPMreq, ifi *Interface) os.Error {
- if ifi == nil {
- return nil
- }
- ifat, err := ifi.Addrs()
+func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ err := leaveIPv4Group(c.fd, ifi, ip)
if err != nil {
- return err
- }
- for _, ifa := range ifat {
- if x := ifa.(*IPAddr).IP.To4(); x != nil {
- copy(mreq.Interface[:], x)
- break
- }
- }
- if bytes.Equal(mreq.Multiaddr[:], IPv4zero) {
- return os.EINVAL
+ return &OpError{"leaveipv4group", c.fd.net, &IPAddr{ip}, err}
}
return nil
}
-func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error {
- mreq := &syscall.IPv6Mreq{}
- copy(mreq.Multiaddr[:], ip)
- if ifi != nil {
- mreq.Interface = uint32(ifi.Index)
- }
- if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)); err != nil {
- return &OpError{"joinipv6group", "udp", &IPAddr{ip}, err}
+func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ err := joinIPv6Group(c.fd, ifi, ip)
+ if err != nil {
+ return &OpError{"joinipv6group", c.fd.net, &IPAddr{ip}, err}
}
return nil
}
-func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error {
- mreq := &syscall.IPv6Mreq{}
- copy(mreq.Multiaddr[:], ip)
- if ifi != nil {
- mreq.Interface = uint32(ifi.Index)
- }
- if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq)); err != nil {
- return &OpError{"leaveipv6group", "udp", &IPAddr{ip}, err}
+func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ err := leaveIPv6Group(c.fd, ifi, ip)
+ if err != nil {
+ return &OpError{"leaveipv6group", c.fd.net, &IPAddr{ip}, err}
}
return nil
}
diff --git a/src/pkg/net/unicast_test.go b/src/pkg/net/unicast_test.go
new file mode 100644
index 000000000..297276d3a
--- /dev/null
+++ b/src/pkg/net/unicast_test.go
@@ -0,0 +1,111 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "io"
+ "runtime"
+ "testing"
+)
+
+var unicastTests = []struct {
+ net string
+ laddr string
+ ipv6 bool
+ packet bool
+}{
+ {net: "tcp4", laddr: "127.0.0.1:0"},
+ {net: "tcp4", laddr: "previous"},
+ {net: "tcp6", laddr: "[::1]:0", ipv6: true},
+ {net: "tcp6", laddr: "previous", ipv6: true},
+ {net: "udp4", laddr: "127.0.0.1:0", packet: true},
+ {net: "udp6", laddr: "[::1]:0", ipv6: true, packet: true},
+}
+
+func TestUnicastTCPAndUDP(t *testing.T) {
+ if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
+ return
+ }
+
+ prevladdr := ""
+ for _, tt := range unicastTests {
+ if tt.ipv6 && !supportsIPv6 {
+ continue
+ }
+ var (
+ fd *netFD
+ closer io.Closer
+ )
+ if !tt.packet {
+ if tt.laddr == "previous" {
+ tt.laddr = prevladdr
+ }
+ l, err := Listen(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ prevladdr = l.Addr().String()
+ closer = l
+ fd = l.(*TCPListener).fd
+ } else {
+ c, err := ListenPacket(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ closer = c
+ fd = c.(*UDPConn).fd
+ }
+ if !tt.ipv6 {
+ testIPv4UnicastSocketOptions(t, fd)
+ } else {
+ testIPv6UnicastSocketOptions(t, fd)
+ }
+ closer.Close()
+ }
+}
+
+func testIPv4UnicastSocketOptions(t *testing.T, fd *netFD) {
+ tos, err := ipv4TOS(fd)
+ if err != nil {
+ t.Fatalf("ipv4TOS failed: %v", err)
+ }
+ t.Logf("IPv4 TOS: %v", tos)
+ err = setIPv4TOS(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv4TOS failed: %v", err)
+ }
+
+ ttl, err := ipv4TTL(fd)
+ if err != nil {
+ t.Fatalf("ipv4TTL failed: %v", err)
+ }
+ t.Logf("IPv4 TTL: %v", ttl)
+ err = setIPv4TTL(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv4TTL failed: %v", err)
+ }
+}
+
+func testIPv6UnicastSocketOptions(t *testing.T, fd *netFD) {
+ tos, err := ipv6TrafficClass(fd)
+ if err != nil {
+ t.Fatalf("ipv6TrafficClass failed: %v", err)
+ }
+ t.Logf("IPv6 TrafficClass: %v", tos)
+ err = setIPv6TrafficClass(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv6TrafficClass failed: %v", err)
+ }
+
+ hoplim, err := ipv6HopLimit(fd)
+ if err != nil {
+ t.Fatalf("ipv6HopLimit failed: %v", err)
+ }
+ t.Logf("IPv6 HopLimit: %v", hoplim)
+ err = setIPv6HopLimit(fd, 1)
+ if err != nil {
+ t.Fatalf("setIPv6HopLimit failed: %v", err)
+ }
+}
diff --git a/src/pkg/net/unixsock.go b/src/pkg/net/unixsock.go
index d5040f9a2..ae0956958 100644
--- a/src/pkg/net/unixsock.go
+++ b/src/pkg/net/unixsock.go
@@ -6,10 +6,6 @@
package net
-import (
- "os"
-)
-
// UnixAddr represents the address of a Unix domain socket end point.
type UnixAddr struct {
Name string
@@ -38,7 +34,7 @@ func (a *UnixAddr) toAddr() Addr {
// ResolveUnixAddr parses addr as a Unix domain socket address.
// The string net gives the network name, "unix", "unixgram" or
// "unixpacket".
-func ResolveUnixAddr(net, addr string) (*UnixAddr, os.Error) {
+func ResolveUnixAddr(net, addr string) (*UnixAddr, error) {
switch net {
case "unix":
case "unixpacket":
diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go
index 7e212df8a..e8087d09a 100644
--- a/src/pkg/net/unixsock_plan9.go
+++ b/src/pkg/net/unixsock_plan9.go
@@ -8,6 +8,7 @@ package net
import (
"os"
+ "time"
)
// UnixConn is an implementation of the Conn interface
@@ -17,17 +18,17 @@ type UnixConn bool
// Implementation of the Conn interface - see Conn for documentation.
// Read implements the net.Conn Read method.
-func (c *UnixConn) Read(b []byte) (n int, err os.Error) {
+func (c *UnixConn) Read(b []byte) (n int, err error) {
return 0, os.EPLAN9
}
// Write implements the net.Conn Write method.
-func (c *UnixConn) Write(b []byte) (n int, err os.Error) {
+func (c *UnixConn) Write(b []byte) (n int, err error) {
return 0, os.EPLAN9
}
// Close closes the Unix domain connection.
-func (c *UnixConn) Close() os.Error {
+func (c *UnixConn) Close() error {
return os.EPLAN9
}
@@ -44,29 +45,29 @@ func (c *UnixConn) RemoteAddr() Addr {
return nil
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *UnixConn) SetTimeout(nsec int64) os.Error {
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *UnixConn) SetDeadline(t time.Time) error {
return os.EPLAN9
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *UnixConn) SetReadTimeout(nsec int64) os.Error {
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *UnixConn) SetReadDeadline(t time.Time) error {
return os.EPLAN9
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error {
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *UnixConn) SetWriteDeadline(t time.Time) error {
return os.EPLAN9
}
// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
+func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
err = os.EPLAN9
return
}
// WriteTo implements the net.PacketConn WriteTo method.
-func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) {
err = os.EPLAN9
return
}
@@ -74,7 +75,7 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
// DialUnix connects to the remote address raddr on the network net,
// which must be "unix" or "unixgram". If laddr is not nil, it is used
// as the local address for the connection.
-func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) {
+func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) {
return nil, os.EPLAN9
}
@@ -85,19 +86,19 @@ type UnixListener bool
// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener.
// Net must be "unix" (stream sockets).
-func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
+func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) {
return nil, os.EPLAN9
}
// Accept implements the Accept method in the Listener interface;
// it waits for the next call and returns a generic Conn.
-func (l *UnixListener) Accept() (c Conn, err os.Error) {
+func (l *UnixListener) Accept() (c Conn, err error) {
return nil, os.EPLAN9
}
// Close stops listening on the Unix address.
// Already accepted connections are not closed.
-func (l *UnixListener) Close() os.Error {
+func (l *UnixListener) Close() error {
return os.EPLAN9
}
diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go
index fccf0189c..e500ddb4e 100644
--- a/src/pkg/net/unixsock_posix.go
+++ b/src/pkg/net/unixsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux openbsd windows
+// +build darwin freebsd linux netbsd openbsd windows
// Unix domain sockets
@@ -11,19 +11,20 @@ package net
import (
"os"
"syscall"
+ "time"
)
-func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err os.Error) {
- var proto int
+func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err error) {
+ var sotype int
switch net {
default:
return nil, UnknownNetworkError(net)
case "unix":
- proto = syscall.SOCK_STREAM
+ sotype = syscall.SOCK_STREAM
case "unixgram":
- proto = syscall.SOCK_DGRAM
+ sotype = syscall.SOCK_DGRAM
case "unixpacket":
- proto = syscall.SOCK_SEQPACKET
+ sotype = syscall.SOCK_SEQPACKET
}
var la, ra syscall.Sockaddr
@@ -37,8 +38,8 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err
}
if raddr != nil {
ra = &syscall.SockaddrUnix{Name: raddr.Name}
- } else if proto != syscall.SOCK_DGRAM || laddr == nil {
- return nil, &OpError{Op: mode, Net: net, Error: errMissingAddress}
+ } else if sotype != syscall.SOCK_DGRAM || laddr == nil {
+ return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress}
}
case "listen":
@@ -47,18 +48,18 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err
}
la = &syscall.SockaddrUnix{Name: laddr.Name}
if raddr != nil {
- return nil, &OpError{Op: mode, Net: net, Addr: raddr, Error: &AddrError{Error: "unexpected remote address", Addr: raddr.String()}}
+ return nil, &OpError{Op: mode, Net: net, Addr: raddr, Err: &AddrError{Err: "unexpected remote address", Addr: raddr.String()}}
}
}
f := sockaddrToUnix
- if proto == syscall.SOCK_DGRAM {
+ if sotype == syscall.SOCK_DGRAM {
f = sockaddrToUnixgram
- } else if proto == syscall.SOCK_SEQPACKET {
+ } else if sotype == syscall.SOCK_SEQPACKET {
f = sockaddrToUnixpacket
}
- fd, oserr := socket(net, syscall.AF_UNIX, proto, 0, la, ra, f)
+ fd, oserr := socket(net, syscall.AF_UNIX, sotype, 0, la, ra, f)
if oserr != nil {
goto Error
}
@@ -69,7 +70,7 @@ Error:
if mode == "listen" {
addr = laddr
}
- return nil, &OpError{Op: mode, Net: net, Addr: addr, Error: oserr}
+ return nil, &OpError{Op: mode, Net: net, Addr: addr, Err: oserr}
}
func sockaddrToUnix(sa syscall.Sockaddr) Addr {
@@ -93,8 +94,8 @@ func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr {
return nil
}
-func protoToNet(proto int) string {
- switch proto {
+func sotypeToNet(sotype int) string {
+ switch sotype {
case syscall.SOCK_STREAM:
return "unix"
case syscall.SOCK_SEQPACKET:
@@ -102,7 +103,7 @@ func protoToNet(proto int) string {
case syscall.SOCK_DGRAM:
return "unixgram"
default:
- panic("protoToNet unknown protocol")
+ panic("sotypeToNet unknown socket type")
}
return ""
}
@@ -120,7 +121,7 @@ func (c *UnixConn) ok() bool { return c != nil && c.fd != nil }
// Implementation of the Conn interface - see Conn for documentation.
// Read implements the net.Conn Read method.
-func (c *UnixConn) Read(b []byte) (n int, err os.Error) {
+func (c *UnixConn) Read(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -128,7 +129,7 @@ func (c *UnixConn) Read(b []byte) (n int, err os.Error) {
}
// Write implements the net.Conn Write method.
-func (c *UnixConn) Write(b []byte) (n int, err os.Error) {
+func (c *UnixConn) Write(b []byte) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
@@ -136,7 +137,7 @@ func (c *UnixConn) Write(b []byte) (n int, err os.Error) {
}
// Close closes the Unix domain connection.
-func (c *UnixConn) Close() os.Error {
+func (c *UnixConn) Close() error {
if !c.ok() {
return os.EINVAL
}
@@ -164,33 +165,33 @@ func (c *UnixConn) RemoteAddr() Addr {
return c.fd.raddr
}
-// SetTimeout implements the net.Conn SetTimeout method.
-func (c *UnixConn) SetTimeout(nsec int64) os.Error {
+// SetDeadline implements the net.Conn SetDeadline method.
+func (c *UnixConn) SetDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setTimeout(c.fd, nsec)
+ return setDeadline(c.fd, t)
}
-// SetReadTimeout implements the net.Conn SetReadTimeout method.
-func (c *UnixConn) SetReadTimeout(nsec int64) os.Error {
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (c *UnixConn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setReadTimeout(c.fd, nsec)
+ return setReadDeadline(c.fd, t)
}
-// SetWriteTimeout implements the net.Conn SetWriteTimeout method.
-func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error {
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (c *UnixConn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return os.EINVAL
}
- return setWriteTimeout(c.fd, nsec)
+ return setWriteDeadline(c.fd, t)
}
// SetReadBuffer sets the size of the operating system's
// receive buffer associated with the connection.
-func (c *UnixConn) SetReadBuffer(bytes int) os.Error {
+func (c *UnixConn) SetReadBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -199,7 +200,7 @@ func (c *UnixConn) SetReadBuffer(bytes int) os.Error {
// SetWriteBuffer sets the size of the operating system's
// transmit buffer associated with the connection.
-func (c *UnixConn) SetWriteBuffer(bytes int) os.Error {
+func (c *UnixConn) SetWriteBuffer(bytes int) error {
if !c.ok() {
return os.EINVAL
}
@@ -212,21 +213,21 @@ func (c *UnixConn) SetWriteBuffer(bytes int) os.Error {
//
// ReadFromUnix can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetReadTimeout.
-func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err os.Error) {
+// see SetDeadline and SetReadDeadline.
+func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
n, sa, err := c.fd.ReadFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
- addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)}
+ addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
}
return
}
// ReadFrom implements the net.PacketConn ReadFrom method.
-func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
+func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
if !c.ok() {
return 0, nil, os.EINVAL
}
@@ -238,13 +239,13 @@ func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) {
//
// WriteToUnix can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
-// see SetTimeout and SetWriteTimeout.
+// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
-func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) {
+func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
- if addr.Net != protoToNet(c.fd.proto) {
+ if addr.Net != sotypeToNet(c.fd.sotype) {
return 0, os.EAFNOSUPPORT
}
sa := &syscall.SockaddrUnix{Name: addr.Name}
@@ -252,35 +253,35 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) {
}
// WriteTo implements the net.PacketConn WriteTo method.
-func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
+func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) {
if !c.ok() {
return 0, os.EINVAL
}
a, ok := addr.(*UnixAddr)
if !ok {
- return 0, &OpError{"writeto", "unix", addr, os.EINVAL}
+ return 0, &OpError{"write", c.fd.net, addr, os.EINVAL}
}
return c.WriteToUnix(b, a)
}
-func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err os.Error) {
+func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
if !c.ok() {
return 0, 0, 0, nil, os.EINVAL
}
n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob)
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
- addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)}
+ addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
}
return
}
-func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err os.Error) {
+func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, os.EINVAL
}
if addr != nil {
- if addr.Net != protoToNet(c.fd.proto) {
+ if addr.Net != sotypeToNet(c.fd.sotype) {
return 0, 0, os.EAFNOSUPPORT
}
sa := &syscall.SockaddrUnix{Name: addr.Name}
@@ -292,12 +293,12 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err
// File returns a copy of the underlying os.File, set to blocking mode.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
-func (c *UnixConn) File() (f *os.File, err os.Error) { return c.fd.dup() }
+func (c *UnixConn) File() (f *os.File, err error) { return c.fd.dup() }
// DialUnix connects to the remote address raddr on the network net,
// which must be "unix" or "unixgram". If laddr is not nil, it is used
// as the local address for the connection.
-func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) {
+func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) {
fd, e := unixSocket(net, laddr, raddr, "dial")
if e != nil {
return nil, e
@@ -315,7 +316,7 @@ type UnixListener struct {
// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener.
// Net must be "unix" (stream sockets).
-func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
+func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
if net != "unix" && net != "unixgram" && net != "unixpacket" {
return nil, UnknownNetworkError(net)
}
@@ -326,17 +327,17 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
if err != nil {
return nil, err
}
- e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog());
- if e1 != 0 {
+ err = syscall.Listen(fd.sysfd, listenerBacklog)
+ if err != nil {
closesocket(fd.sysfd)
- return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Error: os.Errno(e1)}
+ return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err}
}
return &UnixListener{fd, laddr.Name}, nil
}
// AcceptUnix accepts the next incoming call and returns the new connection
// and the remote address.
-func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) {
+func (l *UnixListener) AcceptUnix() (c *UnixConn, err error) {
if l == nil || l.fd == nil {
return nil, os.EINVAL
}
@@ -350,7 +351,7 @@ func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) {
// Accept implements the Accept method in the Listener interface;
// it waits for the next call and returns a generic Conn.
-func (l *UnixListener) Accept() (c Conn, err os.Error) {
+func (l *UnixListener) Accept() (c Conn, err error) {
c1, err := l.AcceptUnix()
if err != nil {
return nil, err
@@ -360,7 +361,7 @@ func (l *UnixListener) Accept() (c Conn, err os.Error) {
// Close stops listening on the Unix address.
// Already accepted connections are not closed.
-func (l *UnixListener) Close() os.Error {
+func (l *UnixListener) Close() error {
if l == nil || l.fd == nil {
return os.EINVAL
}
@@ -386,31 +387,32 @@ func (l *UnixListener) Close() os.Error {
// Addr returns the listener's network address.
func (l *UnixListener) Addr() Addr { return l.fd.laddr }
-// SetTimeout sets the deadline associated wuth the listener
-func (l *UnixListener) SetTimeout(nsec int64) (err os.Error) {
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *UnixListener) SetDeadline(t time.Time) (err error) {
if l == nil || l.fd == nil {
return os.EINVAL
}
- return setTimeout(l.fd, nsec)
+ return setDeadline(l.fd, t)
}
// File returns a copy of the underlying os.File, set to blocking mode.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
-func (l *UnixListener) File() (f *os.File, err os.Error) { return l.fd.dup() }
+func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() }
// ListenUnixgram listens for incoming Unix datagram packets addressed to the
// local address laddr. The returned connection c's ReadFrom
// and WriteTo methods can be used to receive and send UDP
// packets with per-packet addressing. The network net must be "unixgram".
-func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err os.Error) {
+func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err error) {
switch net {
case "unixgram":
default:
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", "unixgram", nil, errMissingAddress}
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
fd, e := unixSocket(net, laddr, nil, "listen")
if e != nil {
diff --git a/src/pkg/net/url/Makefile b/src/pkg/net/url/Makefile
new file mode 100644
index 000000000..bef0647a4
--- /dev/null
+++ b/src/pkg/net/url/Makefile
@@ -0,0 +1,11 @@
+# Copyright 2009 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=net/url
+GOFILES=\
+ url.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go
new file mode 100644
index 000000000..0068e98af
--- /dev/null
+++ b/src/pkg/net/url/url.go
@@ -0,0 +1,664 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package URL parses URLs and implements query escaping.
+// See RFC 3986.
+package url
+
+import (
+ "errors"
+ "strconv"
+ "strings"
+)
+
+// Error reports an error and the operation and URL that caused it.
+type Error struct {
+ Op string
+ URL string
+ Err error
+}
+
+func (e *Error) Error() string { return e.Op + " " + e.URL + ": " + e.Err.Error() }
+
+func ishex(c byte) bool {
+ switch {
+ case '0' <= c && c <= '9':
+ return true
+ case 'a' <= c && c <= 'f':
+ return true
+ case 'A' <= c && c <= 'F':
+ return true
+ }
+ return false
+}
+
+func unhex(c byte) byte {
+ switch {
+ case '0' <= c && c <= '9':
+ return c - '0'
+ case 'a' <= c && c <= 'f':
+ return c - 'a' + 10
+ case 'A' <= c && c <= 'F':
+ return c - 'A' + 10
+ }
+ return 0
+}
+
+type encoding int
+
+const (
+ encodePath encoding = 1 + iota
+ encodeUserPassword
+ encodeQueryComponent
+ encodeFragment
+)
+
+type EscapeError string
+
+func (e EscapeError) Error() string {
+ return "invalid URL escape " + strconv.Quote(string(e))
+}
+
+// Return true if the specified character should be escaped when
+// appearing in a URL string, according to RFC 2396.
+// When 'all' is true the full range of reserved characters are matched.
+func shouldEscape(c byte, mode encoding) bool {
+ // RFC 2396 §2.3 Unreserved characters (alphanum)
+ if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
+ return false
+ }
+ // TODO: Update the character sets after RFC 3986.
+ switch c {
+ case '-', '_', '.', '!', '~', '*', '\'', '(', ')': // §2.3 Unreserved characters (mark)
+ return false
+
+ case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
+ // Different sections of the URL allow a few of
+ // the reserved characters to appear unescaped.
+ switch mode {
+ case encodePath: // §3.3
+ // The RFC allows : @ & = + $ but saves / ; , for assigning
+ // meaning to individual path segments. This package
+ // only manipulates the path as a whole, so we allow those
+ // last two as well. That leaves only ? to escape.
+ return c == '?'
+
+ case encodeUserPassword: // §3.2.2
+ // The RFC allows ; : & = + $ , in userinfo, so we must escape only @ and /.
+ // The parsing of userinfo treats : as special so we must escape that too.
+ return c == '@' || c == '/' || c == ':'
+
+ case encodeQueryComponent: // §3.4
+ // The RFC reserves (so we must escape) everything.
+ return true
+
+ case encodeFragment: // §4.1
+ // The RFC text is silent but the grammar allows
+ // everything, so escape nothing.
+ return false
+ }
+ }
+
+ // Everything else must be escaped.
+ return true
+}
+
+// QueryUnescape does the inverse transformation of QueryEscape, converting
+// %AB into the byte 0xAB and '+' into ' ' (space). It returns an error if
+// any % is not followed by two hexadecimal digits.
+func QueryUnescape(s string) (string, error) {
+ return unescape(s, encodeQueryComponent)
+}
+
+// unescape unescapes a string; the mode specifies
+// which section of the URL string is being unescaped.
+func unescape(s string, mode encoding) (string, error) {
+ // Count %, check that they're well-formed.
+ n := 0
+ hasPlus := false
+ for i := 0; i < len(s); {
+ switch s[i] {
+ case '%':
+ n++
+ if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
+ s = s[i:]
+ if len(s) > 3 {
+ s = s[0:3]
+ }
+ return "", EscapeError(s)
+ }
+ i += 3
+ case '+':
+ hasPlus = mode == encodeQueryComponent
+ i++
+ default:
+ i++
+ }
+ }
+
+ if n == 0 && !hasPlus {
+ return s, nil
+ }
+
+ t := make([]byte, len(s)-2*n)
+ j := 0
+ for i := 0; i < len(s); {
+ switch s[i] {
+ case '%':
+ t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
+ j++
+ i += 3
+ case '+':
+ if mode == encodeQueryComponent {
+ t[j] = ' '
+ } else {
+ t[j] = '+'
+ }
+ j++
+ i++
+ default:
+ t[j] = s[i]
+ j++
+ i++
+ }
+ }
+ return string(t), nil
+}
+
+// QueryEscape escapes the string so it can be safely placed
+// inside a URL query.
+func QueryEscape(s string) string {
+ return escape(s, encodeQueryComponent)
+}
+
+func escape(s string, mode encoding) string {
+ spaceCount, hexCount := 0, 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if shouldEscape(c, mode) {
+ if c == ' ' && mode == encodeQueryComponent {
+ spaceCount++
+ } else {
+ hexCount++
+ }
+ }
+ }
+
+ if spaceCount == 0 && hexCount == 0 {
+ return s
+ }
+
+ t := make([]byte, len(s)+2*hexCount)
+ j := 0
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; {
+ case c == ' ' && mode == encodeQueryComponent:
+ t[j] = '+'
+ j++
+ case shouldEscape(c, mode):
+ t[j] = '%'
+ t[j+1] = "0123456789ABCDEF"[c>>4]
+ t[j+2] = "0123456789ABCDEF"[c&15]
+ j += 3
+ default:
+ t[j] = s[i]
+ j++
+ }
+ }
+ return string(t)
+}
+
+// A URL represents a parsed URL (technically, a URI reference).
+// The general form represented is:
+//
+// scheme://[userinfo@]host/path[?query][#fragment]
+//
+// URLs that do not start with a slash after the scheme are interpreted as:
+//
+// scheme:opaque[?query][#fragment]
+//
+type URL struct {
+ Scheme string
+ Opaque string // encoded opaque data
+ User *Userinfo // username and password information
+ Host string
+ Path string
+ RawQuery string // encoded query values, without '?'
+ Fragment string // fragment for references, without '#'
+}
+
+// User returns a Userinfo containing the provided username
+// and no password set.
+func User(username string) *Userinfo {
+ return &Userinfo{username, "", false}
+}
+
+// UserPassword returns a Userinfo containing the provided username
+// and password.
+// This functionality should only be used with legacy web sites.
+// RFC 2396 warns that interpreting Userinfo this way
+// ``is NOT RECOMMENDED, because the passing of authentication
+// information in clear text (such as URI) has proven to be a
+// security risk in almost every case where it has been used.''
+func UserPassword(username, password string) *Userinfo {
+ return &Userinfo{username, password, true}
+}
+
+// The Userinfo type is an immutable encapsulation of username and
+// password details for a URL. An existing Userinfo value is guaranteed
+// to have a username set (potentially empty, as allowed by RFC 2396),
+// and optionally a password.
+type Userinfo struct {
+ username string
+ password string
+ passwordSet bool
+}
+
+// Username returns the username.
+func (u *Userinfo) Username() string {
+ return u.username
+}
+
+// Password returns the password in case it is set, and whether it is set.
+func (u *Userinfo) Password() (string, bool) {
+ if u.passwordSet {
+ return u.password, true
+ }
+ return "", false
+}
+
+// String returns the encoded userinfo information in the standard form
+// of "username[:password]".
+func (u *Userinfo) String() string {
+ s := escape(u.username, encodeUserPassword)
+ if u.passwordSet {
+ s += ":" + escape(u.password, encodeUserPassword)
+ }
+ return s
+}
+
+// Maybe rawurl is of the form scheme:path.
+// (Scheme must be [a-zA-Z][a-zA-Z0-9+-.]*)
+// If so, return scheme, path; else return "", rawurl.
+func getscheme(rawurl string) (scheme, path string, err error) {
+ for i := 0; i < len(rawurl); i++ {
+ c := rawurl[i]
+ switch {
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
+ // do nothing
+ case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.':
+ if i == 0 {
+ return "", rawurl, nil
+ }
+ case c == ':':
+ if i == 0 {
+ return "", "", errors.New("missing protocol scheme")
+ }
+ return rawurl[0:i], rawurl[i+1:], nil
+ default:
+ // we have encountered an invalid character,
+ // so there is no valid scheme
+ return "", rawurl, nil
+ }
+ }
+ return "", rawurl, nil
+}
+
+// Maybe s is of the form t c u.
+// If so, return t, c u (or t, u if cutc == true).
+// If not, return s, "".
+func split(s string, c byte, cutc bool) (string, string) {
+ for i := 0; i < len(s); i++ {
+ if s[i] == c {
+ if cutc {
+ return s[0:i], s[i+1:]
+ }
+ return s[0:i], s[i:]
+ }
+ }
+ return s, ""
+}
+
+// Parse parses rawurl into a URL structure.
+// The string rawurl is assumed not to have a #fragment suffix.
+// (Web browsers strip #fragment before sending the URL to a web server.)
+// The rawurl may be relative or absolute.
+func Parse(rawurl string) (url *URL, err error) {
+ return parse(rawurl, false)
+}
+
+// ParseRequest parses rawurl into a URL structure. It assumes that
+// rawurl was received from an HTTP request, so the rawurl is interpreted
+// only as an absolute URI or an absolute path.
+// The string rawurl is assumed not to have a #fragment suffix.
+// (Web browsers strip #fragment before sending the URL to a web server.)
+func ParseRequest(rawurl string) (url *URL, err error) {
+ return parse(rawurl, true)
+}
+
+// parse parses a URL from a string in one of two contexts. If
+// viaRequest is true, the URL is assumed to have arrived via an HTTP request,
+// in which case only absolute URLs or path-absolute relative URLs are allowed.
+// If viaRequest is false, all forms of relative URLs are allowed.
+func parse(rawurl string, viaRequest bool) (url *URL, err error) {
+ var rest string
+
+ if rawurl == "" {
+ err = errors.New("empty url")
+ goto Error
+ }
+ url = new(URL)
+
+ // Split off possible leading "http:", "mailto:", etc.
+ // Cannot contain escaped characters.
+ if url.Scheme, rest, err = getscheme(rawurl); err != nil {
+ goto Error
+ }
+
+ rest, url.RawQuery = split(rest, '?', true)
+
+ if !strings.HasPrefix(rest, "/") {
+ if url.Scheme != "" {
+ // We consider rootless paths per RFC 3986 as opaque.
+ url.Opaque = rest
+ return url, nil
+ }
+ if viaRequest {
+ err = errors.New("invalid URI for request")
+ goto Error
+ }
+ }
+
+ if (url.Scheme != "" || !viaRequest) && strings.HasPrefix(rest, "//") && !strings.HasPrefix(rest, "///") {
+ var authority string
+ authority, rest = split(rest[2:], '/', false)
+ url.User, url.Host, err = parseAuthority(authority)
+ if err != nil {
+ goto Error
+ }
+ if strings.Contains(url.Host, "%") {
+ err = errors.New("hexadecimal escape in host")
+ goto Error
+ }
+ }
+ if url.Path, err = unescape(rest, encodePath); err != nil {
+ goto Error
+ }
+ return url, nil
+
+Error:
+ return nil, &Error{"parse", rawurl, err}
+}
+
+func parseAuthority(authority string) (user *Userinfo, host string, err error) {
+ if strings.Index(authority, "@") < 0 {
+ host = authority
+ return
+ }
+ userinfo, host := split(authority, '@', true)
+ if strings.Index(userinfo, ":") < 0 {
+ if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil {
+ return
+ }
+ user = User(userinfo)
+ } else {
+ username, password := split(userinfo, ':', true)
+ if username, err = unescape(username, encodeUserPassword); err != nil {
+ return
+ }
+ if password, err = unescape(password, encodeUserPassword); err != nil {
+ return
+ }
+ user = UserPassword(username, password)
+ }
+ return
+}
+
+// ParseWithReference is like Parse but allows a trailing #fragment.
+func ParseWithReference(rawurlref string) (url *URL, err error) {
+ // Cut off #frag
+ rawurl, frag := split(rawurlref, '#', true)
+ if url, err = Parse(rawurl); err != nil {
+ return nil, err
+ }
+ if frag == "" {
+ return url, nil
+ }
+ if url.Fragment, err = unescape(frag, encodeFragment); err != nil {
+ return nil, &Error{"parse", rawurlref, err}
+ }
+ return url, nil
+}
+
+// String reassembles url into a valid URL string.
+func (url *URL) String() string {
+ // TODO: Rewrite to use bytes.Buffer
+ result := ""
+ if url.Scheme != "" {
+ result += url.Scheme + ":"
+ }
+ if url.Opaque != "" {
+ result += url.Opaque
+ } else {
+ if url.Host != "" || url.User != nil {
+ result += "//"
+ if u := url.User; u != nil {
+ result += u.String() + "@"
+ }
+ result += url.Host
+ }
+ result += escape(url.Path, encodePath)
+ }
+ if url.RawQuery != "" {
+ result += "?" + url.RawQuery
+ }
+ if url.Fragment != "" {
+ result += "#" + escape(url.Fragment, encodeFragment)
+ }
+ return result
+}
+
+// Values maps a string key to a list of values.
+// It is typically used for query parameters and form values.
+// Unlike in the http.Header map, the keys in a Values map
+// are case-sensitive.
+type Values map[string][]string
+
+// Get gets the first value associated with the given key.
+// If there are no values associated with the key, Get returns
+// the empty string. To access multiple values, use the map
+// directly.
+func (v Values) Get(key string) string {
+ if v == nil {
+ return ""
+ }
+ vs, ok := v[key]
+ if !ok || len(vs) == 0 {
+ return ""
+ }
+ return vs[0]
+}
+
+// Set sets the key to value. It replaces any existing
+// values.
+func (v Values) Set(key, value string) {
+ v[key] = []string{value}
+}
+
+// Add adds the key to value. It appends to any existing
+// values associated with key.
+func (v Values) Add(key, value string) {
+ v[key] = append(v[key], value)
+}
+
+// Del deletes the values associated with key.
+func (v Values) Del(key string) {
+ delete(v, key)
+}
+
+// ParseQuery parses the URL-encoded query string and returns
+// a map listing the values specified for each key.
+// ParseQuery always returns a non-nil map containing all the
+// valid query parameters found; err describes the first decoding error
+// encountered, if any.
+func ParseQuery(query string) (m Values, err error) {
+ m = make(Values)
+ err = parseQuery(m, query)
+ return
+}
+
+func parseQuery(m Values, query string) (err error) {
+ for query != "" {
+ key := query
+ if i := strings.IndexAny(key, "&;"); i >= 0 {
+ key, query = key[:i], key[i+1:]
+ } else {
+ query = ""
+ }
+ if key == "" {
+ continue
+ }
+ value := ""
+ if i := strings.Index(key, "="); i >= 0 {
+ key, value = key[:i], key[i+1:]
+ }
+ key, err1 := QueryUnescape(key)
+ if err1 != nil {
+ err = err1
+ continue
+ }
+ value, err1 = QueryUnescape(value)
+ if err1 != nil {
+ err = err1
+ continue
+ }
+ m[key] = append(m[key], value)
+ }
+ return err
+}
+
+// Encode encodes the values into ``URL encoded'' form.
+// e.g. "foo=bar&bar=baz"
+func (v Values) Encode() string {
+ if v == nil {
+ return ""
+ }
+ parts := make([]string, 0, len(v)) // will be large enough for most uses
+ for k, vs := range v {
+ prefix := QueryEscape(k) + "="
+ for _, v := range vs {
+ parts = append(parts, prefix+QueryEscape(v))
+ }
+ }
+ return strings.Join(parts, "&")
+}
+
+// resolvePath applies special path segments from refs and applies
+// them to base, per RFC 2396.
+func resolvePath(basepath string, refpath string) string {
+ base := strings.Split(basepath, "/")
+ refs := strings.Split(refpath, "/")
+ if len(base) == 0 {
+ base = []string{""}
+ }
+ for idx, ref := range refs {
+ switch {
+ case ref == ".":
+ base[len(base)-1] = ""
+ case ref == "..":
+ newLen := len(base) - 1
+ if newLen < 1 {
+ newLen = 1
+ }
+ base = base[0:newLen]
+ base[len(base)-1] = ""
+ default:
+ if idx == 0 || base[len(base)-1] == "" {
+ base[len(base)-1] = ref
+ } else {
+ base = append(base, ref)
+ }
+ }
+ }
+ return strings.Join(base, "/")
+}
+
+// IsAbs returns true if the URL is absolute.
+func (url *URL) IsAbs() bool {
+ return url.Scheme != ""
+}
+
+// Parse parses a URL in the context of a base URL. The URL in ref
+// may be relative or absolute. Parse returns nil, err on parse
+// failure, otherwise its return value is the same as ResolveReference.
+func (base *URL) Parse(ref string) (*URL, error) {
+ refurl, err := Parse(ref)
+ if err != nil {
+ return nil, err
+ }
+ return base.ResolveReference(refurl), nil
+}
+
+// ResolveReference resolves a URI reference to an absolute URI from
+// an absolute base URI, per RFC 2396 Section 5.2. The URI reference
+// may be relative or absolute. ResolveReference always returns a new
+// URL instance, even if the returned URL is identical to either the
+// base or reference. If ref is an absolute URL, then ResolveReference
+// ignores base and returns a copy of ref.
+func (base *URL) ResolveReference(ref *URL) *URL {
+ if ref.IsAbs() {
+ url := *ref
+ return &url
+ }
+ // relativeURI = ( net_path | abs_path | rel_path ) [ "?" query ]
+ url := *base
+ url.RawQuery = ref.RawQuery
+ url.Fragment = ref.Fragment
+ if ref.Opaque != "" {
+ url.Opaque = ref.Opaque
+ url.User = nil
+ url.Host = ""
+ url.Path = ""
+ return &url
+ }
+ if ref.Host != "" || ref.User != nil {
+ // The "net_path" case.
+ url.Host = ref.Host
+ url.User = ref.User
+ }
+ if strings.HasPrefix(ref.Path, "/") {
+ // The "abs_path" case.
+ url.Path = ref.Path
+ } else {
+ // The "rel_path" case.
+ path := resolvePath(base.Path, ref.Path)
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ url.Path = path
+ }
+ return &url
+}
+
+// Query parses RawQuery and returns the corresponding values.
+func (u *URL) Query() Values {
+ v, _ := ParseQuery(u.RawQuery)
+ return v
+}
+
+// RequestURI returns the encoded path?query or opaque?query
+// string that would be used in an HTTP request for u.
+func (u *URL) RequestURI() string {
+ result := u.Opaque
+ if result == "" {
+ result = escape(u.Path, encodePath)
+ if result == "" {
+ result = "/"
+ }
+ }
+ if u.RawQuery != "" {
+ result += "?" + u.RawQuery
+ }
+ return result
+}
diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go
new file mode 100644
index 000000000..9fe5ff886
--- /dev/null
+++ b/src/pkg/net/url/url_test.go
@@ -0,0 +1,771 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package url
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+type URLTest struct {
+ in string
+ out *URL
+ roundtrip string // expected result of reserializing the URL; empty means same as "in".
+}
+
+var urltests = []URLTest{
+ // no path
+ {
+ "http://www.google.com",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ },
+ "",
+ },
+ // path
+ {
+ "http://www.google.com/",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "",
+ },
+ // path with hex escaping
+ {
+ "http://www.google.com/file%20one%26two",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/file one&two",
+ },
+ "http://www.google.com/file%20one&two",
+ },
+ // user
+ {
+ "ftp://webmaster@www.google.com/",
+ &URL{
+ Scheme: "ftp",
+ User: User("webmaster"),
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "",
+ },
+ // escape sequence in username
+ {
+ "ftp://john%20doe@www.google.com/",
+ &URL{
+ Scheme: "ftp",
+ User: User("john doe"),
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "ftp://john%20doe@www.google.com/",
+ },
+ // query
+ {
+ "http://www.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ },
+ "",
+ },
+ // query with hex escaping: NOT parsed
+ {
+ "http://www.google.com/?q=go%20language",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go%20language",
+ },
+ "",
+ },
+ // %20 outside query
+ {
+ "http://www.google.com/a%20b?q=c+d",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/a b",
+ RawQuery: "q=c+d",
+ },
+ "",
+ },
+ // path without leading /, so no parsing
+ {
+ "http:www.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Opaque: "www.google.com/",
+ RawQuery: "q=go+language",
+ },
+ "http:www.google.com/?q=go+language",
+ },
+ // path without leading /, so no parsing
+ {
+ "http:%2f%2fwww.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Opaque: "%2f%2fwww.google.com/",
+ RawQuery: "q=go+language",
+ },
+ "http:%2f%2fwww.google.com/?q=go+language",
+ },
+ // non-authority
+ {
+ "mailto:/webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Path: "/webmaster@golang.org",
+ },
+ "",
+ },
+ // non-authority
+ {
+ "mailto:webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Opaque: "webmaster@golang.org",
+ },
+ "",
+ },
+ // unescaped :// in query should not create a scheme
+ {
+ "/foo?query=http://bad",
+ &URL{
+ Path: "/foo",
+ RawQuery: "query=http://bad",
+ },
+ "",
+ },
+ // leading // without scheme should create an authority
+ {
+ "//foo",
+ &URL{
+ Host: "foo",
+ },
+ "",
+ },
+ // leading // without scheme, with userinfo, path, and query
+ {
+ "//user@foo/path?a=b",
+ &URL{
+ User: User("user"),
+ Host: "foo",
+ Path: "/path",
+ RawQuery: "a=b",
+ },
+ "",
+ },
+ // Three leading slashes isn't an authority, but doesn't return an error.
+ // (We can't return an error, as this code is also used via
+ // ServeHTTP -> ReadRequest -> Parse, which is arguably a
+ // different URL parsing context, but currently shares the
+ // same codepath)
+ {
+ "///threeslashes",
+ &URL{
+ Path: "///threeslashes",
+ },
+ "",
+ },
+ {
+ "http://user:password@google.com",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("user", "password"),
+ Host: "google.com",
+ },
+ "http://user:password@google.com",
+ },
+}
+
+var urlnofragtests = []URLTest{
+ {
+ "http://www.google.com/?q=go+language#foo",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language#foo",
+ },
+ "",
+ },
+}
+
+var urlfragtests = []URLTest{
+ {
+ "http://www.google.com/?q=go+language#foo",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ Fragment: "foo",
+ },
+ "",
+ },
+ {
+ "http://www.google.com/?q=go+language#foo%26bar",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ Fragment: "foo&bar",
+ },
+ "http://www.google.com/?q=go+language#foo&bar",
+ },
+}
+
+// more useful string for debugging than fmt's struct printer
+func ufmt(u *URL) string {
+ var user, pass interface{}
+ if u.User != nil {
+ user = u.User.Username()
+ if p, ok := u.User.Password(); ok {
+ pass = p
+ }
+ }
+ return fmt.Sprintf("opaque=%q, scheme=%q, user=%#v, pass=%#v, host=%q, path=%q, rawq=%q, frag=%q",
+ u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawQuery, u.Fragment)
+}
+
+func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) {
+ for _, tt := range tests {
+ u, err := parse(tt.in)
+ if err != nil {
+ t.Errorf("%s(%q) returned error %s", name, tt.in, err)
+ continue
+ }
+ if !reflect.DeepEqual(u, tt.out) {
+ t.Errorf("%s(%q):\n\thave %v\n\twant %v\n",
+ name, tt.in, ufmt(u), ufmt(tt.out))
+ }
+ }
+}
+
+func TestParse(t *testing.T) {
+ DoTest(t, Parse, "Parse", urltests)
+ DoTest(t, Parse, "Parse", urlnofragtests)
+}
+
+func TestParseWithReference(t *testing.T) {
+ DoTest(t, ParseWithReference, "ParseWithReference", urltests)
+ DoTest(t, ParseWithReference, "ParseWithReference", urlfragtests)
+}
+
+const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path"
+
+var parseRequestUrlTests = []struct {
+ url string
+ expectedValid bool
+}{
+ {"http://foo.com", true},
+ {"http://foo.com/", true},
+ {"http://foo.com/path", true},
+ {"/", true},
+ {pathThatLooksSchemeRelative, true},
+ {"//not.a.user@%66%6f%6f.com/just/a/path/also", true},
+ {"foo.html", false},
+ {"../dir/", false},
+}
+
+func TestParseRequest(t *testing.T) {
+ for _, test := range parseRequestUrlTests {
+ _, err := ParseRequest(test.url)
+ valid := err == nil
+ if valid != test.expectedValid {
+ t.Errorf("Expected valid=%v for %q; got %v", test.expectedValid, test.url, valid)
+ }
+ }
+
+ url, err := ParseRequest(pathThatLooksSchemeRelative)
+ if err != nil {
+ t.Fatalf("Unexpected error %v", err)
+ }
+ if url.Path != pathThatLooksSchemeRelative {
+ t.Errorf("Expected path %q; got %q", pathThatLooksSchemeRelative, url.Path)
+ }
+}
+
+func DoTestString(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) {
+ for _, tt := range tests {
+ u, err := parse(tt.in)
+ if err != nil {
+ t.Errorf("%s(%q) returned error %s", name, tt.in, err)
+ continue
+ }
+ expected := tt.in
+ if len(tt.roundtrip) > 0 {
+ expected = tt.roundtrip
+ }
+ s := u.String()
+ if s != expected {
+ t.Errorf("%s(%q).String() == %q (expected %q)", name, tt.in, s, expected)
+ }
+ }
+}
+
+func TestURLString(t *testing.T) {
+ DoTestString(t, Parse, "Parse", urltests)
+ DoTestString(t, Parse, "Parse", urlnofragtests)
+ DoTestString(t, ParseWithReference, "ParseWithReference", urltests)
+ DoTestString(t, ParseWithReference, "ParseWithReference", urlfragtests)
+}
+
+type EscapeTest struct {
+ in string
+ out string
+ err error
+}
+
+var unescapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "1%41",
+ "1A",
+ nil,
+ },
+ {
+ "1%41%42%43",
+ "1ABC",
+ nil,
+ },
+ {
+ "%4a",
+ "J",
+ nil,
+ },
+ {
+ "%6F",
+ "o",
+ nil,
+ },
+ {
+ "%", // not enough characters after %
+ "",
+ EscapeError("%"),
+ },
+ {
+ "%a", // not enough characters after %
+ "",
+ EscapeError("%a"),
+ },
+ {
+ "%1", // not enough characters after %
+ "",
+ EscapeError("%1"),
+ },
+ {
+ "123%45%6", // not enough characters after %
+ "",
+ EscapeError("%6"),
+ },
+ {
+ "%zzzzz", // invalid hex digits
+ "",
+ EscapeError("%zz"),
+ },
+}
+
+func TestUnescape(t *testing.T) {
+ for _, tt := range unescapeTests {
+ actual, err := QueryUnescape(tt.in)
+ if actual != tt.out || (err != nil) != (tt.err != nil) {
+ t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", tt.in, actual, err, tt.out, tt.err)
+ }
+ }
+}
+
+var escapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "one two",
+ "one+two",
+ nil,
+ },
+ {
+ "10%",
+ "10%25",
+ nil,
+ },
+ {
+ " ?&=#+%!<>#\"{}|\\^[]`☺\t",
+ "+%3F%26%3D%23%2B%25!%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09",
+ nil,
+ },
+}
+
+func TestEscape(t *testing.T) {
+ for _, tt := range escapeTests {
+ actual := QueryEscape(tt.in)
+ if tt.out != actual {
+ t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out)
+ }
+
+ // for bonus points, verify that escape:unescape is an identity.
+ roundtrip, err := QueryUnescape(actual)
+ if roundtrip != tt.in || err != nil {
+ t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]")
+ }
+ }
+}
+
+//var userinfoTests = []UserinfoTest{
+// {"user", "password", "user:password"},
+// {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./",
+// "foo%3Abar:~!%40%23$%25%5E&*()_+%7B%7D%7C%5B%5D%5C-=%60%3A;'%22%3C%3E?,.%2F"},
+//}
+
+type EncodeQueryTest struct {
+ m Values
+ expected string
+ expected1 string
+}
+
+var encodeQueryTests = []EncodeQueryTest{
+ {nil, "", ""},
+ {Values{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"},
+ {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"},
+}
+
+func TestEncodeQuery(t *testing.T) {
+ for _, tt := range encodeQueryTests {
+ if q := tt.m.Encode(); q != tt.expected && q != tt.expected1 {
+ t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected)
+ }
+ }
+}
+
+var resolvePathTests = []struct {
+ base, ref, expected string
+}{
+ {"a/b", ".", "a/"},
+ {"a/b", "c", "a/c"},
+ {"a/b", "..", ""},
+ {"a/", "..", ""},
+ {"a/", "../..", ""},
+ {"a/b/c", "..", "a/"},
+ {"a/b/c", "../d", "a/d"},
+ {"a/b/c", ".././d", "a/d"},
+ {"a/b", "./..", ""},
+ {"a/./b", ".", "a/./"},
+ {"a/../", ".", "a/../"},
+ {"a/.././b", "c", "a/.././c"},
+}
+
+func TestResolvePath(t *testing.T) {
+ for _, test := range resolvePathTests {
+ got := resolvePath(test.base, test.ref)
+ if got != test.expected {
+ t.Errorf("For %q + %q got %q; expected %q", test.base, test.ref, got, test.expected)
+ }
+ }
+}
+
+var resolveReferenceTests = []struct {
+ base, rel, expected string
+}{
+ // Absolute URL references
+ {"http://foo.com?a=b", "https://bar.com/", "https://bar.com/"},
+ {"http://foo.com/", "https://bar.com/?a=b", "https://bar.com/?a=b"},
+ {"http://foo.com/bar", "mailto:foo@example.com", "mailto:foo@example.com"},
+
+ // Path-absolute references
+ {"http://foo.com/bar", "/baz", "http://foo.com/baz"},
+ {"http://foo.com/bar?a=b#f", "/baz", "http://foo.com/baz"},
+ {"http://foo.com/bar?a=b", "/baz?c=d", "http://foo.com/baz?c=d"},
+
+ // Scheme-relative
+ {"https://foo.com/bar?a=b", "//bar.com/quux", "https://bar.com/quux"},
+
+ // Path-relative references:
+
+ // ... current directory
+ {"http://foo.com", ".", "http://foo.com/"},
+ {"http://foo.com/bar", ".", "http://foo.com/"},
+ {"http://foo.com/bar/", ".", "http://foo.com/bar/"},
+
+ // ... going down
+ {"http://foo.com", "bar", "http://foo.com/bar"},
+ {"http://foo.com/", "bar", "http://foo.com/bar"},
+ {"http://foo.com/bar/baz", "quux", "http://foo.com/bar/quux"},
+
+ // ... going up
+ {"http://foo.com/bar/baz", "../quux", "http://foo.com/quux"},
+ {"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"},
+ {"http://foo.com/bar", "..", "http://foo.com/"},
+ {"http://foo.com/bar/baz", "./..", "http://foo.com/"},
+
+ // "." and ".." in the base aren't special
+ {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"},
+
+ // Triple dot isn't special
+ {"http://foo.com/bar", "...", "http://foo.com/..."},
+
+ // Fragment
+ {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"},
+}
+
+func TestResolveReference(t *testing.T) {
+ mustParse := func(url string) *URL {
+ u, err := ParseWithReference(url)
+ if err != nil {
+ t.Fatalf("Expected URL to parse: %q, got error: %v", url, err)
+ }
+ return u
+ }
+ for _, test := range resolveReferenceTests {
+ base := mustParse(test.base)
+ rel := mustParse(test.rel)
+ url := base.ResolveReference(rel)
+ urlStr := url.String()
+ if urlStr != test.expected {
+ t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr)
+ }
+ }
+
+ // Test that new instances are returned.
+ base := mustParse("http://foo.com/")
+ abs := base.ResolveReference(mustParse("."))
+ if base == abs {
+ t.Errorf("Expected no-op reference to return new URL instance.")
+ }
+ barRef := mustParse("http://bar.com/")
+ abs = base.ResolveReference(barRef)
+ if abs == barRef {
+ t.Errorf("Expected resolution of absolute reference to return new URL instance.")
+ }
+
+ // Test the convenience wrapper too
+ base = mustParse("http://foo.com/path/one/")
+ abs, _ = base.Parse("../two")
+ expected := "http://foo.com/path/two"
+ if abs.String() != expected {
+ t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected)
+ }
+ _, err := base.Parse("")
+ if err == nil {
+ t.Errorf("Expected an error from Parse wrapper parsing an empty string.")
+ }
+
+ // Ensure Opaque resets the URL.
+ base = mustParse("scheme://user@foo.com/bar")
+ abs = base.ResolveReference(&URL{Opaque: "opaque"})
+ want := mustParse("scheme:opaque")
+ if *abs != *want {
+ t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", abs, want)
+ }
+}
+
+func TestResolveReferenceOpaque(t *testing.T) {
+ mustParse := func(url string) *URL {
+ u, err := ParseWithReference(url)
+ if err != nil {
+ t.Fatalf("Expected URL to parse: %q, got error: %v", url, err)
+ }
+ return u
+ }
+ for _, test := range resolveReferenceTests {
+ base := mustParse(test.base)
+ rel := mustParse(test.rel)
+ url := base.ResolveReference(rel)
+ urlStr := url.String()
+ if urlStr != test.expected {
+ t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr)
+ }
+ }
+
+ // Test that new instances are returned.
+ base := mustParse("http://foo.com/")
+ abs := base.ResolveReference(mustParse("."))
+ if base == abs {
+ t.Errorf("Expected no-op reference to return new URL instance.")
+ }
+ barRef := mustParse("http://bar.com/")
+ abs = base.ResolveReference(barRef)
+ if abs == barRef {
+ t.Errorf("Expected resolution of absolute reference to return new URL instance.")
+ }
+
+ // Test the convenience wrapper too
+ base = mustParse("http://foo.com/path/one/")
+ abs, _ = base.Parse("../two")
+ expected := "http://foo.com/path/two"
+ if abs.String() != expected {
+ t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected)
+ }
+ _, err := base.Parse("")
+ if err == nil {
+ t.Errorf("Expected an error from Parse wrapper parsing an empty string.")
+ }
+
+}
+
+func TestQueryValues(t *testing.T) {
+ u, _ := Parse("http://x.com?foo=bar&bar=1&bar=2")
+ v := u.Query()
+ if len(v) != 2 {
+ t.Errorf("got %d keys in Query values, want 2", len(v))
+ }
+ if g, e := v.Get("foo"), "bar"; g != e {
+ t.Errorf("Get(foo) = %q, want %q", g, e)
+ }
+ // Case sensitive:
+ if g, e := v.Get("Foo"), ""; g != e {
+ t.Errorf("Get(Foo) = %q, want %q", g, e)
+ }
+ if g, e := v.Get("bar"), "1"; g != e {
+ t.Errorf("Get(bar) = %q, want %q", g, e)
+ }
+ if g, e := v.Get("baz"), ""; g != e {
+ t.Errorf("Get(baz) = %q, want %q", g, e)
+ }
+ v.Del("bar")
+ if g, e := v.Get("bar"), ""; g != e {
+ t.Errorf("second Get(bar) = %q, want %q", g, e)
+ }
+}
+
+type parseTest struct {
+ query string
+ out Values
+}
+
+var parseTests = []parseTest{
+ {
+ query: "a=1&b=2",
+ out: Values{"a": []string{"1"}, "b": []string{"2"}},
+ },
+ {
+ query: "a=1&a=2&a=banana",
+ out: Values{"a": []string{"1", "2", "banana"}},
+ },
+ {
+ query: "ascii=%3Ckey%3A+0x90%3E",
+ out: Values{"ascii": []string{"<key: 0x90>"}},
+ },
+ {
+ query: "a=1;b=2",
+ out: Values{"a": []string{"1"}, "b": []string{"2"}},
+ },
+ {
+ query: "a=1&a=2;a=banana",
+ out: Values{"a": []string{"1", "2", "banana"}},
+ },
+}
+
+func TestParseQuery(t *testing.T) {
+ for i, test := range parseTests {
+ form, err := ParseQuery(test.query)
+ if err != nil {
+ t.Errorf("test %d: Unexpected error: %v", i, err)
+ continue
+ }
+ if len(form) != len(test.out) {
+ t.Errorf("test %d: len(form) = %d, want %d", i, len(form), len(test.out))
+ }
+ for k, evs := range test.out {
+ vs, ok := form[k]
+ if !ok {
+ t.Errorf("test %d: Missing key %q", i, k)
+ continue
+ }
+ if len(vs) != len(evs) {
+ t.Errorf("test %d: len(form[%q]) = %d, want %d", i, k, len(vs), len(evs))
+ continue
+ }
+ for j, ev := range evs {
+ if v := vs[j]; v != ev {
+ t.Errorf("test %d: form[%q][%d] = %q, want %q", i, k, j, v, ev)
+ }
+ }
+ }
+ }
+}
+
+type RequestURITest struct {
+ url *URL
+ out string
+}
+
+var requritests = []RequestURITest{
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "",
+ },
+ "/",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a b",
+ },
+ "/a%20b",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a b",
+ RawQuery: "q=go+language",
+ },
+ "/a%20b?q=go+language",
+ },
+ {
+ &URL{
+ Scheme: "myschema",
+ Opaque: "opaque",
+ },
+ "opaque",
+ },
+ {
+ &URL{
+ Scheme: "myschema",
+ Opaque: "opaque",
+ RawQuery: "q=go+language",
+ },
+ "opaque?q=go+language",
+ },
+}
+
+func TestRequestURI(t *testing.T) {
+ for _, tt := range requritests {
+ s := tt.url.RequestURI()
+ if s != tt.out {
+ t.Errorf("%#v.RequestURI() == %q (expected %q)", tt.url, s, tt.out)
+ }
+ }
+}