diff options
Diffstat (limited to 'src/pkg/net')
131 files changed, 5243 insertions, 1202 deletions
diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go index 388eab4fe..3090d3019 100644 --- a/src/pkg/net/cgo_bsd.go +++ b/src/pkg/net/cgo_bsd.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // +build !netgo -// +build darwin dragonfly freebsd +// +build darwin dragonfly freebsd solaris package net diff --git a/src/pkg/net/cgo_unix_test.go b/src/pkg/net/cgo_unix_test.go new file mode 100644 index 000000000..33566ce9c --- /dev/null +++ b/src/pkg/net/cgo_unix_test.go @@ -0,0 +1,24 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build cgo,!netgo +// +build darwin dragonfly freebsd linux netbsd openbsd + +package net + +import "testing" + +func TestCgoLookupIP(t *testing.T) { + host := "localhost" + _, err, ok := cgoLookupIP(host) + if !ok { + t.Errorf("cgoLookupIP must not be a placeholder") + } + if err != nil { + t.Errorf("cgoLookupIP failed: %v", err) + } + if _, err := goLookupIP(host); err != nil { + t.Errorf("goLookupIP failed: %v", err) + } +} diff --git a/src/pkg/net/conn_test.go b/src/pkg/net/conn_test.go index 98bd69549..37bb4e2c0 100644 --- a/src/pkg/net/conn_test.go +++ b/src/pkg/net/conn_test.go @@ -16,11 +16,11 @@ import ( var connTests = []struct { net string - addr func() string + addr string }{ - {"tcp", func() string { return "127.0.0.1:0" }}, - {"unix", testUnixAddr}, - {"unixpacket", testUnixAddr}, + {"tcp", "127.0.0.1:0"}, + {"unix", testUnixAddr()}, + {"unixpacket", testUnixAddr()}, } // someTimeout is used just to test that net.Conn implementations @@ -31,18 +31,21 @@ const someTimeout = 10 * time.Second func TestConnAndListener(t *testing.T) { for _, tt := range connTests { switch tt.net { - case "unix", "unixpacket": + case "unix": switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": continue } - if tt.net == "unixpacket" && runtime.GOOS != "linux" { + case "unixpacket": + switch runtime.GOOS { + case "darwin", "nacl", "openbsd", "plan9", "windows": + continue + case "freebsd": // FreeBSD 8 doesn't support unixpacket continue } } - addr := tt.addr() - ln, err := Listen(tt.net, addr) + ln, err := Listen(tt.net, tt.addr) if err != nil { t.Fatalf("Listen failed: %v", err) } @@ -52,8 +55,10 @@ func TestConnAndListener(t *testing.T) { case "unix", "unixpacket": os.Remove(addr) } - }(ln, tt.net, addr) - ln.Addr() + }(ln, tt.net, tt.addr) + if ln.Addr().Network() != tt.net { + t.Fatalf("got %v; expected %v", ln.Addr().Network(), tt.net) + } done := make(chan int) go transponder(t, ln, done) @@ -63,8 +68,9 @@ func TestConnAndListener(t *testing.T) { t.Fatalf("Dial failed: %v", err) } defer c.Close() - c.LocalAddr() - c.RemoteAddr() + if c.LocalAddr().Network() != tt.net || c.LocalAddr().Network() != tt.net { + t.Fatalf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), tt.net, tt.net) + } c.SetDeadline(time.Now().Add(someTimeout)) c.SetReadDeadline(time.Now().Add(someTimeout)) c.SetWriteDeadline(time.Now().Add(someTimeout)) @@ -96,8 +102,11 @@ func transponder(t *testing.T, ln Listener, done chan<- int) { return } defer c.Close() - c.LocalAddr() - c.RemoteAddr() + network := ln.Addr().Network() + if c.LocalAddr().Network() != network || c.LocalAddr().Network() != network { + t.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network) + return + } c.SetDeadline(time.Now().Add(someTimeout)) c.SetReadDeadline(time.Now().Add(someTimeout)) c.SetWriteDeadline(time.Now().Add(someTimeout)) diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 6304818bf..93569c253 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -44,6 +44,12 @@ type Dialer struct { // destination is a host name that has multiple address family // DNS records. DualStack bool + + // KeepAlive specifies the keep-alive period for an active + // network connection. + // If zero, keep-alives are not enabled. Network protocols + // that do not support keep-alives ignore this field. + KeepAlive time.Duration } // Return either now+Timeout or Deadline, whichever comes first. @@ -162,9 +168,19 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { return dialMulti(network, address, d.LocalAddr, ras, deadline) } } - return dial(network, ra.toAddr(), dialer, d.deadline()) + c, err := dial(network, ra.toAddr(), dialer, d.deadline()) + if d.KeepAlive > 0 && err == nil { + if tc, ok := c.(*TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(d.KeepAlive) + testHookSetKeepAlive() + } + } + return c, err } +var testHookSetKeepAlive = func() {} // changed by dial_test.go + // dialMulti attempts to establish connections to each destination of // the list of addresses. It will return the first established // connection and close the other connections. Otherwise it returns @@ -172,7 +188,6 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Conn, error) { type racer struct { Conn - Addr error } // Sig controls the flow of dial results on lane. It passes a @@ -184,7 +199,7 @@ func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Con go func(ra Addr) { c, err := dialSingle(net, addr, la, ra, deadline) if _, ok := <-sig; ok { - lane <- racer{c, ra, err} + lane <- racer{c, err} } else if err == nil { // We have to return the resources // that belong to the other @@ -195,7 +210,6 @@ func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Con }(ra.toAddr()) } defer close(sig) - var failAddr Addr lastErr := errTimeout nracers := len(ras) for nracers > 0 { @@ -205,12 +219,11 @@ func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Con if racer.error == nil { return racer.Conn, nil } - failAddr = racer.Addr lastErr = racer.error nracers-- } } - return nil, &OpError{Op: "dial", Net: net, Addr: failAddr, Err: lastErr} + return nil, lastErr } // dialSingle attempts to establish and returns a single connection to diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go index f1d813f41..f9260fd28 100644 --- a/src/pkg/net/dial_test.go +++ b/src/pkg/net/dial_test.go @@ -58,7 +58,7 @@ func TestDialTimeout(t *testing.T) { errc <- err }() } - case "darwin", "windows": + case "darwin", "plan9", "windows": // At least OS X 10.7 seems to accept any number of // connections, ignoring listen's backlog, so resort // to connecting to a hopefully-dead 127/8 address. @@ -141,13 +141,13 @@ func TestSelfConnect(t *testing.T) { n = 1000 } switch runtime.GOOS { - case "darwin", "dragonfly", "freebsd", "netbsd", "openbsd", "plan9", "windows": + case "darwin", "dragonfly", "freebsd", "netbsd", "openbsd", "plan9", "solaris", "windows": // Non-Linux systems take a long time to figure // out that there is nothing listening on localhost. n = 100 } for i := 0; i < n; i++ { - c, err := Dial("tcp", addr) + c, err := DialTimeout("tcp", addr, time.Millisecond) if err == nil { c.Close() t.Errorf("#%d: Dial %q succeeded", i, addr) @@ -425,60 +425,6 @@ func numFD() int { panic("numFDs not implemented on " + runtime.GOOS) } -// Assert that a failed Dial attempt does not leak -// runtime.PollDesc structures -func TestDialFailPDLeak(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode") - } - if runtime.GOOS == "windows" && runtime.GOARCH == "386" { - // Just skip the test because it takes too long. - t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) - } - - maxprocs := runtime.GOMAXPROCS(0) - loops := 10 + maxprocs - // 500 is enough to turn over the chunk of pollcache. - // See allocPollDesc in runtime/netpoll.goc. - const count = 500 - var old runtime.MemStats // used by sysdelta - runtime.ReadMemStats(&old) - sysdelta := func() uint64 { - var new runtime.MemStats - runtime.ReadMemStats(&new) - delta := old.Sys - new.Sys - old = new - return delta - } - d := &Dialer{Timeout: time.Nanosecond} // don't bother TCP with handshaking - failcount := 0 - for i := 0; i < loops; i++ { - var wg sync.WaitGroup - for i := 0; i < count; i++ { - wg.Add(1) - go func() { - defer wg.Done() - if c, err := d.Dial("tcp", "127.0.0.1:1"); err == nil { - t.Error("dial should not succeed") - c.Close() - } - }() - } - wg.Wait() - if t.Failed() { - t.FailNow() - } - if delta := sysdelta(); delta > 0 { - failcount++ - } - // there are always some allocations on the first loop - if failcount > maxprocs+2 { - t.Error("detected possible memory leak in runtime") - t.FailNow() - } - } -} - func TestDialer(t *testing.T) { ln, err := Listen("tcp4", "127.0.0.1:0") if err != nil { @@ -555,3 +501,36 @@ func TestDialDualStackLocalhost(t *testing.T) { } } } + +func TestDialerKeepAlive(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + defer func() { + testHookSetKeepAlive = func() {} + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Close() + } + }() + for _, keepAlive := range []bool{false, true} { + got := false + testHookSetKeepAlive = func() { got = true } + var d Dialer + if keepAlive { + d.KeepAlive = 30 * time.Second + } + c, err := d.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + if got != keepAlive { + t.Errorf("Dialer.KeepAlive = %v: SetKeepAlive called = %v, want %v", d.KeepAlive, got, !got) + } + } +} diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index b4ebad0e0..df5895afa 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -104,31 +104,7 @@ var googleaddrsipv4 = []string{ "[::ffff:%02x%02x:%02x%02x]:80", "[0:0:0:0:0000:ffff:%d.%d.%d.%d]:80", "[0:0:0:0:000000:ffff:%d.%d.%d.%d]:80", - "[0:0:0:0:0:ffff::%d.%d.%d.%d]:80", -} - -func TestDNSThreadLimit(t *testing.T) { - if testing.Short() || !*testExternal { - t.Skip("skipping test to avoid external network") - } - - const N = 10000 - c := make(chan int, N) - for i := 0; i < N; i++ { - go func(i int) { - LookupIP(fmt.Sprintf("%d.net-test.golang.org", i)) - c <- 1 - }(i) - } - // Don't bother waiting for the stragglers; stop at 0.9 N. - for i := 0; i < N*9/10; i++ { - if i%100 == 0 { - //println("TestDNSThreadLimit:", i) - } - <-c - } - - // If we're still here, it worked. + "[0:0:0:0::ffff:%d.%d.%d.%d]:80", } func TestDialGoogleIPv4(t *testing.T) { diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index 01db43729..9bffa11f9 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -191,10 +191,10 @@ func (addrs byPriorityWeight) shuffleByWeight() { } for sum > 0 && len(addrs) > 1 { s := 0 - n := rand.Intn(sum + 1) + n := rand.Intn(sum) for i := range addrs { s += int(addrs[i].Weight) - if s >= n { + if s > n { if i > 0 { t := addrs[i] copy(addrs[1:i+1], addrs[0:i]) diff --git a/src/pkg/net/dnsclient_test.go b/src/pkg/net/dnsclient_test.go new file mode 100644 index 000000000..435eb3550 --- /dev/null +++ b/src/pkg/net/dnsclient_test.go @@ -0,0 +1,69 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "math/rand" + "testing" +) + +func checkDistribution(t *testing.T, data []*SRV, margin float64) { + sum := 0 + for _, srv := range data { + sum += int(srv.Weight) + } + + results := make(map[string]int) + + count := 1000 + for j := 0; j < count; j++ { + d := make([]*SRV, len(data)) + copy(d, data) + byPriorityWeight(d).shuffleByWeight() + key := d[0].Target + results[key] = results[key] + 1 + } + + actual := results[data[0].Target] + expected := float64(count) * float64(data[0].Weight) / float64(sum) + diff := float64(actual) - expected + t.Logf("actual: %v diff: %v e: %v m: %v", actual, diff, expected, margin) + if diff < 0 { + diff = -diff + } + if diff > (expected * margin) { + t.Errorf("missed target weight: expected %v, %v", expected, actual) + } +} + +func testUniformity(t *testing.T, size int, margin float64) { + rand.Seed(1) + data := make([]*SRV, size) + for i := 0; i < size; i++ { + data[i] = &SRV{Target: string('a' + i), Weight: 1} + } + checkDistribution(t, data, margin) +} + +func TestUniformity(t *testing.T) { + testUniformity(t, 2, 0.05) + testUniformity(t, 3, 0.10) + testUniformity(t, 10, 0.20) + testWeighting(t, 0.05) +} + +func testWeighting(t *testing.T, margin float64) { + rand.Seed(1) + data := []*SRV{ + {Target: "a", Weight: 60}, + {Target: "b", Weight: 30}, + {Target: "c", Weight: 10}, + } + checkDistribution(t, data, margin) +} + +func TestWeighting(t *testing.T) { + testWeighting(t, 0.05) +} diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go index 16cf420dc..3713efd0e 100644 --- a/src/pkg/net/dnsclient_unix.go +++ b/src/pkg/net/dnsclient_unix.go @@ -2,13 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris // DNS client: see RFC 1035. // Has to be linked into package net for Dial. // TODO(rsc): -// Check periodically whether /etc/resolv.conf has changed. // Could potentially handle many outstanding lookups faster. // Could have a small cache. // Random UDP source port (net.Dial should do that for us). @@ -19,6 +18,7 @@ package net import ( "io" "math/rand" + "os" "sync" "time" ) @@ -156,32 +156,90 @@ func convertRR_AAAA(records []dnsRR) []IP { return addrs } -var cfg *dnsConfig -var dnserr error +var cfg struct { + ch chan struct{} + mu sync.RWMutex // protects dnsConfig and dnserr + dnsConfig *dnsConfig + dnserr error +} +var onceLoadConfig sync.Once -func loadConfig() { cfg, dnserr = dnsReadConfig() } +// Assume dns config file is /etc/resolv.conf here +func loadDefaultConfig() { + loadConfig("/etc/resolv.conf", 5*time.Second, nil) +} -var onceLoadConfig sync.Once +func loadConfig(resolvConfPath string, reloadTime time.Duration, quit <-chan chan struct{}) { + var mtime time.Time + cfg.ch = make(chan struct{}, 1) + if fi, err := os.Stat(resolvConfPath); err != nil { + cfg.dnserr = err + } else { + mtime = fi.ModTime() + cfg.dnsConfig, cfg.dnserr = dnsReadConfig(resolvConfPath) + } + go func() { + for { + time.Sleep(reloadTime) + select { + case qresp := <-quit: + qresp <- struct{}{} + return + case <-cfg.ch: + } + + // In case of error, we keep the previous config + fi, err := os.Stat(resolvConfPath) + if err != nil { + continue + } + // If the resolv.conf mtime didn't change, do not reload + m := fi.ModTime() + if m.Equal(mtime) { + continue + } + mtime = m + // In case of error, we keep the previous config + ncfg, err := dnsReadConfig(resolvConfPath) + if err != nil || len(ncfg.servers) == 0 { + continue + } + cfg.mu.Lock() + cfg.dnsConfig = ncfg + cfg.dnserr = nil + cfg.mu.Unlock() + } + }() +} func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) { if !isDomainName(name) { return name, nil, &DNSError{Err: "invalid domain name", Name: name} } - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr + onceLoadConfig.Do(loadDefaultConfig) + + select { + case cfg.ch <- struct{}{}: + default: + } + + cfg.mu.RLock() + defer cfg.mu.RUnlock() + + if cfg.dnserr != nil || cfg.dnsConfig == nil { + err = cfg.dnserr return } // If name is rooted (trailing dot) or has enough dots, // try it by itself first. rooted := len(name) > 0 && name[len(name)-1] == '.' - if rooted || count(name, '.') >= cfg.ndots { + if rooted || count(name, '.') >= cfg.dnsConfig.ndots { rname := name if !rooted { rname += "." } // Can try as ordinary name. - cname, addrs, err = tryOneName(cfg, rname, qtype) + cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype) if err == nil { return } @@ -191,12 +249,12 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) } // Otherwise, try suffixes. - for i := 0; i < len(cfg.search); i++ { - rname := name + "." + cfg.search[i] + for i := 0; i < len(cfg.dnsConfig.search); i++ { + rname := name + "." + cfg.dnsConfig.search[i] if rname[len(rname)-1] != '.' { rname += "." } - cname, addrs, err = tryOneName(cfg, rname, qtype) + cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype) if err == nil { return } @@ -207,7 +265,7 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) if !rooted { rname += "." } - cname, addrs, err = tryOneName(cfg, rname, qtype) + cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype) if err == nil { return } @@ -232,11 +290,6 @@ func goLookupHost(name string) (addrs []string, err error) { if len(addrs) > 0 { return } - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } ips, err := goLookupIP(name) if err != nil { return @@ -267,11 +320,6 @@ func goLookupIP(name string) (addrs []IP, err error) { return } } - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } var records []dnsRR var cname string var err4, err6 error @@ -307,11 +355,6 @@ func goLookupIP(name string) (addrs []IP, err error) { // depending on our lookup code, so that Go and C get the same // answers. func goLookupCNAME(name string) (cname string, err error) { - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } _, rr, err := lookup(name, dnsTypeCNAME) if err != nil { return diff --git a/src/pkg/net/dnsclient_unix_test.go b/src/pkg/net/dnsclient_unix_test.go index 47dcb563b..2350142d6 100644 --- a/src/pkg/net/dnsclient_unix_test.go +++ b/src/pkg/net/dnsclient_unix_test.go @@ -2,12 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd solaris package net import ( + "io" + "io/ioutil" + "os" + "path" + "reflect" "testing" + "time" ) func TestTCPLookup(t *testing.T) { @@ -25,3 +31,129 @@ func TestTCPLookup(t *testing.T) { t.Fatalf("exchange failed: %v", err) } } + +type resolvConfTest struct { + *testing.T + dir string + path string + started bool + quitc chan chan struct{} +} + +func newResolvConfTest(t *testing.T) *resolvConfTest { + dir, err := ioutil.TempDir("", "resolvConfTest") + if err != nil { + t.Fatalf("could not create temp dir: %v", err) + } + + // Disable the default loadConfig + onceLoadConfig.Do(func() {}) + + r := &resolvConfTest{ + T: t, + dir: dir, + path: path.Join(dir, "resolv.conf"), + quitc: make(chan chan struct{}), + } + + return r +} + +func (r *resolvConfTest) Start() { + loadConfig(r.path, 100*time.Millisecond, r.quitc) + r.started = true +} + +func (r *resolvConfTest) SetConf(s string) { + // Make sure the file mtime will be different once we're done here, + // even on systems with coarse (1s) mtime resolution. + time.Sleep(time.Second) + + f, err := os.OpenFile(r.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + r.Fatalf("failed to create temp file %s: %v", r.path, err) + } + if _, err := io.WriteString(f, s); err != nil { + f.Close() + r.Fatalf("failed to write temp file: %v", err) + } + f.Close() + + if r.started { + cfg.ch <- struct{}{} // fill buffer + cfg.ch <- struct{}{} // wait for reload to begin + cfg.ch <- struct{}{} // wait for reload to complete + } +} + +func (r *resolvConfTest) WantServers(want []string) { + cfg.mu.RLock() + defer cfg.mu.RUnlock() + if got := cfg.dnsConfig.servers; !reflect.DeepEqual(got, want) { + r.Fatalf("Unexpected dns server loaded, got %v want %v", got, want) + } +} + +func (r *resolvConfTest) Close() { + resp := make(chan struct{}) + r.quitc <- resp + <-resp + if err := os.RemoveAll(r.dir); err != nil { + r.Logf("failed to remove temp dir %s: %v", r.dir, err) + } +} + +func TestReloadResolvConfFail(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + r := newResolvConfTest(t) + defer r.Close() + + // resolv.conf.tmp does not exist yet + r.Start() + if _, err := goLookupIP("golang.org"); err == nil { + t.Fatal("goLookupIP(missing) succeeded") + } + + r.SetConf("nameserver 8.8.8.8") + if _, err := goLookupIP("golang.org"); err != nil { + t.Fatalf("goLookupIP(missing; good) failed: %v", err) + } + + // Using a bad resolv.conf while we had a good + // one before should not update the config + r.SetConf("") + if _, err := goLookupIP("golang.org"); err != nil { + t.Fatalf("goLookupIP(missing; good; bad) failed: %v", err) + } +} + +func TestReloadResolvConfChange(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + r := newResolvConfTest(t) + defer r.Close() + + r.SetConf("nameserver 8.8.8.8") + r.Start() + + if _, err := goLookupIP("golang.org"); err != nil { + t.Fatalf("goLookupIP(good) failed: %v", err) + } + r.WantServers([]string{"[8.8.8.8]"}) + + // Using a bad resolv.conf when we had a good one + // before should not update the config + r.SetConf("") + if _, err := goLookupIP("golang.org"); err != nil { + t.Fatalf("goLookupIP(good; bad) failed: %v", err) + } + + // A new good config should get picked up + r.SetConf("nameserver 8.8.4.4") + r.WantServers([]string{"[8.8.4.4]"}) +} diff --git a/src/pkg/net/dnsconfig_unix.go b/src/pkg/net/dnsconfig_unix.go index 2f0f6c031..af288253e 100644 --- a/src/pkg/net/dnsconfig_unix.go +++ b/src/pkg/net/dnsconfig_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris // Read system DNS config from /etc/resolv.conf @@ -20,14 +20,13 @@ type dnsConfig struct { // See resolv.conf(5) on a Linux machine. // TODO(rsc): Supposed to call uname() and chop the beginning // of the host name to get the default search domain. -// We assume it's in resolv.conf anyway. -func dnsReadConfig() (*dnsConfig, error) { - file, err := open("/etc/resolv.conf") +func dnsReadConfig(filename string) (*dnsConfig, error) { + file, err := open(filename) if err != nil { return nil, &DNSConfigError{err} } conf := new(dnsConfig) - conf.servers = make([]string, 3)[0:0] // small, but the standard limit + conf.servers = make([]string, 0, 3) // small, but the standard limit conf.search = make([]string, 0) conf.ndots = 1 conf.timeout = 5 diff --git a/src/pkg/net/dnsconfig_unix_test.go b/src/pkg/net/dnsconfig_unix_test.go new file mode 100644 index 000000000..37ed4931d --- /dev/null +++ b/src/pkg/net/dnsconfig_unix_test.go @@ -0,0 +1,46 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import "testing" + +func TestDNSReadConfig(t *testing.T) { + dnsConfig, err := dnsReadConfig("testdata/resolv.conf") + if err != nil { + t.Fatal(err) + } + + if len(dnsConfig.servers) != 1 { + t.Errorf("len(dnsConfig.servers) = %d; want %d", len(dnsConfig.servers), 1) + } + if dnsConfig.servers[0] != "[192.168.1.1]" { + t.Errorf("dnsConfig.servers[0] = %s; want %s", dnsConfig.servers[0], "[192.168.1.1]") + } + + if len(dnsConfig.search) != 1 { + t.Errorf("len(dnsConfig.search) = %d; want %d", len(dnsConfig.search), 1) + } + if dnsConfig.search[0] != "Home" { + t.Errorf("dnsConfig.search[0] = %s; want %s", dnsConfig.search[0], "Home") + } + + if dnsConfig.ndots != 5 { + t.Errorf("dnsConfig.ndots = %d; want %d", dnsConfig.ndots, 5) + } + + if dnsConfig.timeout != 10 { + t.Errorf("dnsConfig.timeout = %d; want %d", dnsConfig.timeout, 10) + } + + if dnsConfig.attempts != 3 { + t.Errorf("dnsConfig.attempts = %d; want %d", dnsConfig.attempts, 3) + } + + if dnsConfig.rotate != true { + t.Errorf("dnsConfig.rotate = %t; want %t", dnsConfig.rotate, true) + } +} diff --git a/src/pkg/net/fd_mutex_test.go b/src/pkg/net/fd_mutex_test.go index 8383084b7..c34ec59b9 100644 --- a/src/pkg/net/fd_mutex_test.go +++ b/src/pkg/net/fd_mutex_test.go @@ -63,7 +63,8 @@ func TestMutexCloseUnblock(t *testing.T) { for i := 0; i < 4; i++ { go func() { if mu.RWLock(true) { - t.Fatal("broken") + t.Error("broken") + return } c <- true }() @@ -138,36 +139,44 @@ func TestMutexStress(t *testing.T) { switch r.Intn(3) { case 0: if !mu.Incref() { - t.Fatal("broken") + t.Error("broken") + return } if mu.Decref() { - t.Fatal("broken") + t.Error("broken") + return } case 1: if !mu.RWLock(true) { - t.Fatal("broken") + t.Error("broken") + return } // Ensure that it provides mutual exclusion for readers. if readState[0] != readState[1] { - t.Fatal("broken") + t.Error("broken") + return } readState[0]++ readState[1]++ if mu.RWUnlock(true) { - t.Fatal("broken") + t.Error("broken") + return } case 2: if !mu.RWLock(false) { - t.Fatal("broken") + t.Error("broken") + return } // Ensure that it provides mutual exclusion for writers. if writeState[0] != writeState[1] { - t.Fatal("broken") + t.Error("broken") + return } writeState[0]++ writeState[1]++ if mu.RWUnlock(false) { - t.Fatal("broken") + t.Error("broken") + return } } } diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go index acc829402..5fe8effc2 100644 --- a/src/pkg/net/fd_plan9.go +++ b/src/pkg/net/fd_plan9.go @@ -13,12 +13,23 @@ import ( // Network file descritor. type netFD struct { - proto, name, dir string - ctl, data *os.File - laddr, raddr Addr + // locking/lifetime of sysfd + serialize access to Read and Write methods + fdmu fdMutex + + // immutable until Close + proto string + n string + dir string + ctl, data *os.File + laddr, raddr Addr } +var ( + netdir string // default network +) + func sysInit() { + netdir = "/net" } func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { @@ -27,16 +38,99 @@ func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline ti return dialChannel(net, ra, dialer, deadline) } -func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { - return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, data, laddr, raddr} +func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { + return &netFD{proto: proto, n: name, dir: netdir + "/" + proto + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil +} + +func (fd *netFD) init() error { + // stub for future fd.pd.Init(fd) + return nil +} + +func (fd *netFD) name() string { + var ls, rs string + if fd.laddr != nil { + ls = fd.laddr.String() + } + if fd.raddr != nil { + rs = fd.raddr.String() + } + return fd.proto + ":" + ls + "->" + rs } func (fd *netFD) ok() bool { return fd != nil && fd.ctl != nil } +func (fd *netFD) destroy() { + if !fd.ok() { + return + } + err := fd.ctl.Close() + if fd.data != nil { + if err1 := fd.data.Close(); err1 != nil && err == nil { + err = err1 + } + } + fd.ctl = nil + fd.data = nil +} + +// Add a reference to this fd. +// Returns an error if the fd cannot be used. +func (fd *netFD) incref() error { + if !fd.fdmu.Incref() { + return errClosing + } + return nil +} + +// Remove a reference to this FD and close if we've been asked to do so +// (and there are no references left). +func (fd *netFD) decref() { + if fd.fdmu.Decref() { + fd.destroy() + } +} + +// Add a reference to this fd and lock for reading. +// Returns an error if the fd cannot be used. +func (fd *netFD) readLock() error { + if !fd.fdmu.RWLock(true) { + return errClosing + } + return nil +} + +// Unlock for reading and remove a reference to this FD. +func (fd *netFD) readUnlock() { + if fd.fdmu.RWUnlock(true) { + fd.destroy() + } +} + +// Add a reference to this fd and lock for writing. +// Returns an error if the fd cannot be used. +func (fd *netFD) writeLock() error { + if !fd.fdmu.RWLock(false) { + return errClosing + } + return nil +} + +// Unlock for writing and remove a reference to this FD. +func (fd *netFD) writeUnlock() { + if fd.fdmu.RWUnlock(false) { + fd.destroy() + } +} + func (fd *netFD) Read(b []byte) (n int, err error) { if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } + if err := fd.readLock(); err != nil { + return 0, err + } + defer fd.readUnlock() n, err = fd.data.Read(b) if fd.proto == "udp" && err == io.EOF { n = 0 @@ -49,17 +143,21 @@ func (fd *netFD) Write(b []byte) (n int, err error) { if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } + if err := fd.writeLock(); err != nil { + return 0, err + } + defer fd.writeUnlock() return fd.data.Write(b) } -func (fd *netFD) CloseRead() error { +func (fd *netFD) closeRead() error { if !fd.ok() { return syscall.EINVAL } return syscall.EPLAN9 } -func (fd *netFD) CloseWrite() error { +func (fd *netFD) closeWrite() error { if !fd.ok() { return syscall.EINVAL } @@ -67,6 +165,9 @@ func (fd *netFD) CloseWrite() error { } func (fd *netFD) Close() error { + if !fd.fdmu.IncrefAndClose() { + return errClosing + } if !fd.ok() { return syscall.EINVAL } diff --git a/src/pkg/net/fd_poll_nacl.go b/src/pkg/net/fd_poll_nacl.go new file mode 100644 index 000000000..a3701f876 --- /dev/null +++ b/src/pkg/net/fd_poll_nacl.go @@ -0,0 +1,94 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "time" +) + +type pollDesc struct { + fd *netFD + closing bool +} + +func (pd *pollDesc) Init(fd *netFD) error { pd.fd = fd; return nil } + +func (pd *pollDesc) Close() {} + +func (pd *pollDesc) Lock() {} + +func (pd *pollDesc) Unlock() {} + +func (pd *pollDesc) Wakeup() {} + +func (pd *pollDesc) Evict() bool { + pd.closing = true + if pd.fd != nil { + syscall.StopIO(pd.fd.sysfd) + } + return false +} + +func (pd *pollDesc) Prepare(mode int) error { + if pd.closing { + return errClosing + } + return nil +} + +func (pd *pollDesc) PrepareRead() error { return pd.Prepare('r') } + +func (pd *pollDesc) PrepareWrite() error { return pd.Prepare('w') } + +func (pd *pollDesc) Wait(mode int) error { + if pd.closing { + return errClosing + } + return errTimeout +} + +func (pd *pollDesc) WaitRead() error { return pd.Wait('r') } + +func (pd *pollDesc) WaitWrite() error { return pd.Wait('w') } + +func (pd *pollDesc) WaitCanceled(mode int) {} + +func (pd *pollDesc) WaitCanceledRead() {} + +func (pd *pollDesc) WaitCanceledWrite() {} + +func (fd *netFD) setDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'r'+'w') +} + +func (fd *netFD) setReadDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'r') +} + +func (fd *netFD) setWriteDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'w') +} + +func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { + d := t.UnixNano() + if t.IsZero() { + d = 0 + } + if err := fd.incref(); err != nil { + return err + } + switch mode { + case 'r': + syscall.SetReadDeadline(fd.sysfd, d) + case 'w': + syscall.SetWriteDeadline(fd.sysfd, d) + case 'r' + 'w': + syscall.SetReadDeadline(fd.sysfd, d) + syscall.SetWriteDeadline(fd.sysfd, d) + } + fd.decref() + return nil +} diff --git a/src/pkg/net/fd_poll_runtime.go b/src/pkg/net/fd_poll_runtime.go index e2b276886..2bddc836c 100644 --- a/src/pkg/net/fd_poll_runtime.go +++ b/src/pkg/net/fd_poll_runtime.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows solaris package net @@ -12,6 +12,9 @@ import ( "time" ) +// runtimeNano returns the current value of the runtime clock in nanoseconds. +func runtimeNano() int64 + func runtime_pollServerInit() func runtime_pollOpen(fd uintptr) (uintptr, int) func runtime_pollClose(ctx uintptr) @@ -128,7 +131,7 @@ func (fd *netFD) setWriteDeadline(t time.Time) error { } func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { - d := t.UnixNano() + d := runtimeNano() + int64(t.Sub(time.Now())) if t.IsZero() { d = 0 } diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 9ed4f7536..b82ecd11c 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris package net @@ -75,19 +75,47 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr) error { if err := fd.pd.PrepareWrite(); err != nil { return err } + switch err := syscall.Connect(fd.sysfd, ra); err { + case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: + case nil, syscall.EISCONN: + return nil + case syscall.EINVAL: + // On Solaris we can see EINVAL if the socket has + // already been accepted and closed by the server. + // Treat this as a successful connection--writes to + // the socket will see EOF. For details and a test + // case in C see http://golang.org/issue/6828. + if runtime.GOOS == "solaris" { + return nil + } + fallthrough + default: + return err + } for { - err := syscall.Connect(fd.sysfd, ra) - if err == nil || err == syscall.EISCONN { - break + // Performing multiple connect system calls on a + // non-blocking socket under Unix variants does not + // necessarily result in earlier errors being + // returned. Instead, once runtime-integrated network + // poller tells us that the socket is ready, get the + // SO_ERROR socket option to see if the connection + // succeeded or failed. See issue 7474 for further + // details. + if err := fd.pd.WaitWrite(); err != nil { + return err } - if err != syscall.EINPROGRESS && err != syscall.EALREADY && err != syscall.EINTR { + nerr, err := syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + if err != nil { return err } - if err = fd.pd.WaitWrite(); err != nil { + switch err := syscall.Errno(nerr); err { + case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: + case syscall.Errno(0), syscall.EISCONN: + return nil + default: return err } } - return nil } func (fd *netFD) destroy() { @@ -180,11 +208,11 @@ func (fd *netFD) shutdown(how int) error { return nil } -func (fd *netFD) CloseRead() error { +func (fd *netFD) closeRead() error { return fd.shutdown(syscall.SHUT_RD) } -func (fd *netFD) CloseWrite() error { +func (fd *netFD) closeWrite() error { return fd.shutdown(syscall.SHUT_WR) } @@ -215,7 +243,7 @@ func (fd *netFD) Read(p []byte) (n int, err error) { return } -func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { +func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { if err := fd.readLock(); err != nil { return 0, nil, err } @@ -242,7 +270,7 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return } -func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { +func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { if err := fd.readLock(); err != nil { return 0, 0, 0, nil, err } @@ -313,7 +341,7 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { return nn, err } -func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { +func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { if err := fd.writeLock(); err != nil { return 0, err } @@ -338,7 +366,7 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return } -func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { +func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { if err := fd.writeLock(); err != nil { return 0, 0, err } @@ -347,7 +375,7 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob return 0, 0, &OpError{"write", fd.net, fd.raddr, err} } for { - err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) + n, err = syscall.SendmsgN(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { if err = fd.pd.WaitWrite(); err == nil { continue @@ -356,7 +384,6 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob break } if err == nil { - n = len(p) oobn = len(oob) } else { err = &OpError{"write", fd.net, fd.raddr, err} @@ -455,7 +482,6 @@ func dupCloseOnExecOld(fd int) (newfd int, err error) { func (fd *netFD) dup() (f *os.File, err error) { ns, err := dupCloseOnExec(fd.sysfd) if err != nil { - syscall.ForkLock.RUnlock() return nil, &OpError{"dup", fd.net, fd.laddr, err} } diff --git a/src/pkg/net/fd_unix_test.go b/src/pkg/net/fd_unix_test.go index 65d3e69a7..fe8e8ff6a 100644 --- a/src/pkg/net/fd_unix_test.go +++ b/src/pkg/net/fd_unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd solaris package net diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 630fc5e6f..a1f6bc5f8 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -119,7 +119,7 @@ func (o *operation) InitBuf(buf []byte) { o.buf.Len = uint32(len(buf)) o.buf.Buf = nil if len(buf) != 0 { - o.buf.Buf = (*byte)(unsafe.Pointer(&buf[0])) + o.buf.Buf = &buf[0] } } @@ -431,11 +431,11 @@ func (fd *netFD) shutdown(how int) error { return nil } -func (fd *netFD) CloseRead() error { +func (fd *netFD) closeRead() error { return fd.shutdown(syscall.SHUT_RD) } -func (fd *netFD) CloseWrite() error { +func (fd *netFD) closeWrite() error { return fd.shutdown(syscall.SHUT_WR) } @@ -458,7 +458,7 @@ func (fd *netFD) Read(buf []byte) (int, error) { return n, err } -func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { +func (fd *netFD) readFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { if len(buf) == 0 { return 0, nil, nil } @@ -497,7 +497,7 @@ func (fd *netFD) Write(buf []byte) (int, error) { }) } -func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) { +func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) { if len(buf) == 0 { return 0, nil } @@ -628,10 +628,10 @@ func (fd *netFD) dup() (*os.File, error) { var errNoSupport = errors.New("address family not supported") -func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { +func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { return 0, 0, 0, nil, errNoSupport } -func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { +func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { return 0, 0, errNoSupport } diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go index f6ee1c29e..068f0881d 100644 --- a/src/pkg/net/file_plan9.go +++ b/src/pkg/net/file_plan9.go @@ -43,7 +43,7 @@ func newFileFD(f *os.File) (net *netFD, err error) { } comp := splitAtBytes(path, "/") n := len(comp) - if n < 3 || comp[0] != "net" { + if n < 3 || comp[0][0:3] != "net" { return nil, syscall.EPLAN9 } @@ -58,7 +58,7 @@ func newFileFD(f *os.File) (net *netFD, err error) { } defer close(fd) - dir := "/net/" + comp[n-2] + dir := netdir + "/" + comp[n-2] ctl = os.NewFile(uintptr(fd), dir+"/"+file) ctl.Seek(0, 0) var buf [16]byte @@ -71,19 +71,19 @@ func newFileFD(f *os.File) (net *netFD, err error) { if len(comp) < 4 { return nil, errors.New("could not find control file for connection") } - dir := "/net/" + comp[1] + "/" + name + dir := netdir + "/" + comp[1] + "/" + name ctl, err = os.OpenFile(dir+"/ctl", os.O_RDWR, 0) if err != nil { return nil, err } defer close(int(ctl.Fd())) } - dir := "/net/" + comp[1] + "/" + name + dir := netdir + "/" + comp[1] + "/" + name laddr, err := readPlan9Addr(comp[1], dir+"/local") if err != nil { return nil, err } - return newFD(comp[1], name, ctl, nil, laddr, nil), nil + return newFD(comp[1], name, ctl, nil, laddr, nil) } func newFileConn(f *os.File) (c Conn, err error) { diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go index acaf18851..d81bca782 100644 --- a/src/pkg/net/file_test.go +++ b/src/pkg/net/file_test.go @@ -174,12 +174,14 @@ var filePacketConnTests = []struct { {net: "udp6", addr: "[::1]", ipv6: true}, + {net: "ip4:icmp", addr: "127.0.0.1"}, + {net: "unixgram", addr: "@gotest3/net", linux: true}, } func TestFilePacketConn(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": t.Skipf("skipping test on %q", runtime.GOOS) } @@ -187,6 +189,10 @@ func TestFilePacketConn(t *testing.T) { if skipServerTest(tt.net, "unixgram", tt.addr, tt.ipv6, false, tt.linux) { continue } + if os.Getuid() != 0 && tt.net == "ip4:icmp" { + t.Log("skipping test; must be root") + continue + } testFilePacketConnListen(t, tt.net, tt.addr) switch tt.addr { case "", "0.0.0.0", "[::ffff:0.0.0.0]", "[::]": diff --git a/src/pkg/net/file_unix.go b/src/pkg/net/file_unix.go index 8fe1b0eb0..07b3ecf62 100644 --- a/src/pkg/net/file_unix.go +++ b/src/pkg/net/file_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris package net @@ -129,6 +129,8 @@ func FilePacketConn(f *os.File) (c PacketConn, err error) { switch fd.laddr.(type) { case *UDPAddr: return newUDPConn(fd), nil + case *IPAddr: + return newIPConn(fd), nil case *UnixAddr: return newUnixConn(fd), nil } diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go index b07ed0baa..2fe358e07 100644 --- a/src/pkg/net/hosts_test.go +++ b/src/pkg/net/hosts_test.go @@ -41,7 +41,7 @@ func TestLookupStaticHost(t *testing.T) { if len(ips) != len(tt.ips) { t.Errorf("# of hosts = %v; want %v", len(ips), len(tt.ips)) - return + continue } for k, v := range ips { if tt.ips[k].String() != v { diff --git a/src/pkg/net/http/cgi/host.go b/src/pkg/net/http/cgi/host.go index d27cc4dc9..ec95a972c 100644 --- a/src/pkg/net/http/cgi/host.go +++ b/src/pkg/net/http/cgi/host.go @@ -214,12 +214,17 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { internalError(err) return } + if hook := testHookStartProcess; hook != nil { + hook(cmd.Process) + } defer cmd.Wait() defer stdoutRead.Close() linebody := bufio.NewReaderSize(stdoutRead, 1024) headers := make(http.Header) statusCode := 0 + headerLines := 0 + sawBlankLine := false for { line, isPrefix, err := linebody.ReadLine() if isPrefix { @@ -236,8 +241,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } if len(line) == 0 { + sawBlankLine = true break } + headerLines++ parts := strings.SplitN(string(line), ":", 2) if len(parts) < 2 { h.printf("cgi: bogus header line: %s", string(line)) @@ -263,6 +270,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + if headerLines == 0 || !sawBlankLine { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: no headers") + return + } if loc := headers.Get("Location"); loc != "" { if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { @@ -274,6 +286,12 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + if statusCode == 0 && headers.Get("Content-Type") == "" { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: missing required Content-Type in headers") + return + } + if statusCode == 0 { statusCode = http.StatusOK } @@ -292,6 +310,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { _, err = io.Copy(rw, linebody) if err != nil { h.printf("cgi: copy error: %v", err) + // And kill the child CGI process so we don't hang on + // the deferred cmd.Wait above if the error was just + // the client (rw) going away. If it was a read error + // (because the child died itself), then the extra + // kill of an already-dead process is harmless (the PID + // won't be reused until the Wait above). + cmd.Process.Kill() } } @@ -348,3 +373,5 @@ func upperCaseAndUnderscore(r rune) rune { // TODO: other transformations in spec or practice? return r } + +var testHookStartProcess func(*os.Process) // nil except for some tests diff --git a/src/pkg/net/http/cgi/matryoshka_test.go b/src/pkg/net/http/cgi/matryoshka_test.go index e1a78c8f6..18c4803e7 100644 --- a/src/pkg/net/http/cgi/matryoshka_test.go +++ b/src/pkg/net/http/cgi/matryoshka_test.go @@ -9,15 +9,25 @@ package cgi import ( + "bytes" + "errors" "fmt" + "io" "net/http" + "net/http/httptest" "os" + "runtime" "testing" + "time" ) // This test is a CGI host (testing host.go) that runs its own binary // as a child process testing the other half of CGI (child.go). func TestHostingOurselves(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + h := &Handler{ Path: os.Args[0], Root: "/test.go", @@ -51,8 +61,88 @@ func TestHostingOurselves(t *testing.T) { } } -// Test that a child handler only writing headers works. +type customWriterRecorder struct { + w io.Writer + *httptest.ResponseRecorder +} + +func (r *customWriterRecorder) Write(p []byte) (n int, err error) { + return r.w.Write(p) +} + +type limitWriter struct { + w io.Writer + n int +} + +func (w *limitWriter) Write(p []byte) (n int, err error) { + if len(p) > w.n { + p = p[:w.n] + } + if len(p) > 0 { + n, err = w.w.Write(p) + w.n -= n + } + if w.n == 0 { + err = errors.New("past write limit") + } + return +} + +// If there's an error copying the child's output to the parent, test +// that we kill the child. +func TestKillChildAfterCopyError(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + + defer func() { testHookStartProcess = nil }() + proc := make(chan *os.Process, 1) + testHookStartProcess = func(p *os.Process) { + proc <- p + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil) + rec := httptest.NewRecorder() + var out bytes.Buffer + const writeLen = 50 << 10 + rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec} + + donec := make(chan bool, 1) + go func() { + h.ServeHTTP(rw, req) + donec <- true + }() + + select { + case <-donec: + if out.Len() != writeLen || out.Bytes()[0] != 'a' { + t.Errorf("unexpected output: %q", out.Bytes()) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?") + select { + case p := <-proc: + p.Kill() + t.Logf("killed process") + default: + t.Logf("didn't kill process") + } + } +} + +// Test that a child handler writing only headers works. +// golang.org/issue/7196 func TestChildOnlyHeaders(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + h := &Handler{ Path: os.Args[0], Root: "/test.go", @@ -67,18 +157,63 @@ func TestChildOnlyHeaders(t *testing.T) { } } +// golang.org/issue/7198 +func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } +func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } +func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") } + +func want500Test(t *testing.T, path string) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap) + if replay.Code != 500 { + t.Errorf("Got code %d; want 500", replay.Code) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + // Note: not actually a test. func TestBeChildCGIProcess(t *testing.T) { if os.Getenv("REQUEST_METHOD") == "" { // Not in a CGI environment; skipping test. return } + switch os.Getenv("REQUEST_URI") { + case "/immediate-disconnect": + os.Exit(0) + case "/no-content-type": + fmt.Printf("Content-Length: 6\n\nHello\n") + os.Exit(0) + case "/empty-headers": + fmt.Printf("\nHello") + os.Exit(0) + } Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Test-Header", "X-Test-Value") req.ParseForm() if req.FormValue("no-body") == "1" { return } + if req.FormValue("write-forever") == "1" { + io.Copy(rw, neverEnding('a')) + for { + time.Sleep(5 * time.Second) // hang forever, until killed + } + } fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") for k, vv := range req.Form { for _, v := range vv { diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go index 91db01724..749f29d32 100644 --- a/src/pkg/net/http/chunked.go +++ b/src/pkg/net/http/chunked.go @@ -4,13 +4,14 @@ // The wire protocol for HTTP's "chunked" Transfer-Encoding. -// This code is duplicated in httputil/chunked.go. +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package http import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -57,26 +58,45 @@ func (cr *chunkedReader) beginChunk() { } } -func (cr *chunkedReader) Read(b []uint8) (n int, err error) { - if cr.err != nil { - return 0, cr.err +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 } - if cr.n == 0 { - cr.beginChunk() - if cr.err != nil { - return 0, cr.err + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue } - } - if uint64(len(b)) > cr.n { - b = b[0:cr.n] - } - n, cr.err = cr.r.Read(b) - cr.n -= uint64(n) - if cr.n == 0 && cr.err == nil { - // end of chunk (CRLF) - if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } } } } diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go index 0b18c7b55..34544790a 100644 --- a/src/pkg/net/http/chunked_test.go +++ b/src/pkg/net/http/chunked_test.go @@ -2,17 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This code is duplicated in httputil/chunked_test.go. +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package http import ( + "bufio" "bytes" "fmt" "io" "io/ioutil" - "runtime" + "strings" "testing" ) @@ -41,9 +42,77 @@ func TestChunk(t *testing.T) { } } +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := newChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := newChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + func TestChunkReaderAllocs(t *testing.T) { - // temporarily set GOMAXPROCS to 1 as we are testing memory allocations - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if testing.Short() { + t.Skip("skipping in short mode") + } var buf bytes.Buffer w := newChunkedWriter(&buf) a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") @@ -52,26 +121,23 @@ func TestChunkReaderAllocs(t *testing.T) { w.Write(c) w.Close() - r := newChunkedReader(&buf) readBuf := make([]byte, len(a)+len(b)+len(c)+1) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - m0 := ms.Mallocs - - n, err := io.ReadFull(r, readBuf) - - runtime.ReadMemStats(&ms) - mallocs := ms.Mallocs - m0 - if mallocs > 1 { - t.Errorf("%d mallocs; want <= 1", mallocs) - } - - if n != len(readBuf)-1 { - t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) - } - if err != io.ErrUnexpectedEOF { - t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, 0) + bufr.Reset(byter) + r := newChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) } } diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 22f2e865c..a5a3abe61 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -14,9 +14,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "log" "net/url" "strings" + "sync" + "time" ) // A Client is an HTTP client. Its zero value (DefaultClient) is a @@ -52,6 +55,20 @@ type Client struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar CookieJar + + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time, any + // redirects, and reading the response body. The timer remains + // running after Get, Head, Post, or Do return and will + // interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + // + // The Client's Transport must support the CancelRequest + // method or Client will return errors when attempting to make + // a request with Get, Head, Post, or Do. Client's default + // Transport (DefaultTransport) supports CancelRequest. + Timeout time.Duration } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -74,8 +91,9 @@ type RoundTripper interface { // authentication, or cookies. // // RoundTrip should not modify the request, except for - // consuming and closing the Body. The request's URL and - // Header fields are guaranteed to be initialized. + // consuming and closing the Body, including on errors. The + // request's URL and Header fields are guaranteed to be + // initialized. RoundTrip(*Request) (*Response, error) } @@ -97,7 +115,7 @@ func (c *Client) send(req *Request) (*Response, error) { req.AddCookie(cookie) } } - resp, err := send(req, c.Transport) + resp, err := send(req, c.transport()) if err != nil { return nil, err } @@ -123,6 +141,9 @@ func (c *Client) send(req *Request) (*Response, error) { // (typically Transport) may not be able to re-use a persistent TCP // connection to the server for a subsequent "keep-alive" request. // +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { if req.Method == "GET" || req.Method == "HEAD" { @@ -134,22 +155,28 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { return c.send(req) } +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + // send issues an HTTP request. // Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { - t = DefaultTransport - if t == nil { - err = errors.New("http: no Client.Transport or DefaultTransport") - return - } + req.closeBody() + return nil, errors.New("http: no Client.Transport or DefaultTransport") } if req.URL == nil { + req.closeBody() return nil, errors.New("http: nil Request.URL") } if req.RequestURI != "" { + req.closeBody() return nil, errors.New("http: Request.RequestURI can't be set in client requests.") } @@ -257,21 +284,40 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo var via []*Request if ireq.URL == nil { + ireq.closeBody() return nil, errors.New("http: nil Request.URL") } + var reqmu sync.Mutex // guards req req := ireq + + var timer *time.Timer + if c.Timeout > 0 { + type canceler interface { + CancelRequest(*Request) + } + tr, ok := c.transport().(canceler) + if !ok { + return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) + } + timer = time.AfterFunc(c.Timeout, func() { + reqmu.Lock() + defer reqmu.Unlock() + tr.CancelRequest(req) + }) + } + urlStr := "" // next relative or absolute URL to fetch (after first request) redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { - req = new(Request) - req.Method = ireq.Method + nreq := new(Request) + nreq.Method = ireq.Method if ireq.Method == "POST" || ireq.Method == "PUT" { - req.Method = "GET" + nreq.Method = "GET" } - req.Header = make(Header) - req.URL, err = base.Parse(urlStr) + nreq.Header = make(Header) + nreq.URL, err = base.Parse(urlStr) if err != nil { break } @@ -279,15 +325,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo // Add the Referer header. lastReq := via[len(via)-1] if lastReq.URL.Scheme != "https" { - req.Header.Set("Referer", lastReq.URL.String()) + nreq.Header.Set("Referer", lastReq.URL.String()) } - err = redirectChecker(req, via) + err = redirectChecker(nreq, via) if err != nil { redirectFailed = true break } } + reqmu.Lock() + req = nreq + reqmu.Unlock() } urlStr = req.URL.String() @@ -296,6 +345,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo } if shouldRedirect(resp.StatusCode) { + // Read the body if small so underlying TCP connection will be re-used. + // No need to check for errors: if it fails, Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + } resp.Body.Close() if urlStr = resp.Header.Get("Location"); urlStr == "" { err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) @@ -305,7 +360,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo via = append(via, req) continue } - return + if timer != nil { + resp.Body = &cancelTimerBody{timer, resp.Body} + } + return resp, nil } method := ireq.Method @@ -349,7 +407,7 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro // Caller should close resp.Body when done reading from it. // // If the provided body is also an io.Closer, it is closed after the -// body is successfully written to the server. +// request. func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { @@ -408,3 +466,22 @@ func (c *Client) Head(url string) (resp *Response, err error) { } return c.doFollowingRedirects(req, shouldRedirectGet) } + +type cancelTimerBody struct { + t *time.Timer + rc io.ReadCloser +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == io.EOF { + b.t.Stop() + } + return +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.t.Stop() + return err +} diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 997d04151..6392c1baf 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -15,14 +15,18 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" . "net/http" "net/http/httptest" "net/url" + "reflect" + "sort" "strconv" "strings" "sync" "testing" + "time" ) var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { @@ -54,6 +58,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } } +type chanWriter chan string + +func (w chanWriter) Write(p []byte) (n int, err error) { + w <- string(p) + return len(p), nil +} + func TestClient(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) @@ -373,24 +384,6 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { return j.perURL[u.Host] } -func TestRedirectCookiesOnRequest(t *testing.T) { - defer afterTest(t) - var ts *httptest.Server - ts = httptest.NewServer(echoCookiesRedirectHandler) - defer ts.Close() - c := &Client{} - req, _ := NewRequest("GET", ts.URL, nil) - req.AddCookie(expectedCookies[0]) - // TODO: Uncomment when an implementation of a RFC6265 cookie jar lands. - _ = c - // resp, _ := c.Do(req) - // matchReturnedCookies(t, expectedCookies, resp.Cookies()) - - req, _ = NewRequest("GET", ts.URL, nil) - // resp, _ = c.Do(req) - // matchReturnedCookies(t, expectedCookies[1:], resp.Cookies()) -} - func TestRedirectCookiesJar(t *testing.T) { defer afterTest(t) var ts *httptest.Server @@ -410,8 +403,8 @@ func TestRedirectCookiesJar(t *testing.T) { } func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { - t.Logf("Received cookies: %v", given) if len(given) != len(expected) { + t.Logf("Received cookies: %v", given) t.Errorf("Expected %d cookies, got %d", len(expected), len(given)) } for _, ec := range expected { @@ -582,6 +575,8 @@ func TestClientInsecureTransport(t *testing.T) { ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) defer ts.Close() // TODO(bradfitz): add tests for skipping hostname checks too? @@ -603,6 +598,16 @@ func TestClientInsecureTransport(t *testing.T) { res.Body.Close() } } + + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } + } func TestClientErrorWithRequestURI(t *testing.T) { @@ -653,6 +658,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) trans := newTLSTransport(t, ts) trans.TLSClientConfig.ServerName = "badserver" @@ -664,6 +671,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) } + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } } // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName @@ -696,6 +711,33 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { res.Body.Close() } +func TestResponseSetsTLSConnectionState(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + tr := newTLSTransport(t, ts) + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://example.com/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.TLS == nil { + t.Fatal("Response didn't set TLS Connection State.") + } + if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + t.Errorf("TLS Cipher Suite = %d; want %d", got, want) + } +} + // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { defer afterTest(t) @@ -799,3 +841,198 @@ func TestBasicAuth(t *testing.T) { t.Errorf("Invalid auth %q", auth) } } + +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + sawRoot := make(chan bool, 1) + sawSlow := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/" { + sawRoot <- true + Redirect(w, r, "/slow", StatusFound) + return + } + if r.URL.Path == "/slow" { + w.Write([]byte("Hello")) + w.(Flusher).Flush() + sawSlow <- true + time.Sleep(2 * time.Second) + return + } + })) + defer ts.Close() + const timeout = 500 * time.Millisecond + c := &Client{ + Timeout: timeout, + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + select { + case <-sawRoot: + // good. + default: + t.Fatal("handler never got / request") + } + + select { + case <-sawSlow: + // good. + default: + t.Fatal("handler never got /slow request") + } + + errc := make(chan error, 1) + go func() { + _, err := ioutil.ReadAll(res.Body) + errc <- err + res.Body.Close() + }() + + const failTime = timeout * 2 + select { + case err := <-errc: + if err == nil { + t.Error("expected error from ReadAll") + } + // Expected error. + case <-time.After(failTime): + t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + } +} + +func TestClientRedirectEatsBody(t *testing.T) { + defer afterTest(t) + saw := make(chan string, 2) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + saw <- r.RemoteAddr + if r.URL.Path == "/" { + Redirect(w, r, "/foo", StatusFound) // which includes a body + } + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var first string + select { + case first = <-saw: + default: + t.Fatal("server didn't see a request") + } + + var second string + select { + case second = <-saw: + default: + t.Fatal("server didn't see a second request") + } + + if first != second { + t.Fatal("server saw different client ports before & after the redirect") + } +} + +// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF. +type eofReaderFunc func() + +func (f eofReaderFunc) Read(p []byte) (n int, err error) { + f() + return 0, io.EOF +} + +func TestClientTrailers(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") + w.Header().Add("Trailer", "Server-Trailer-C") + + var decl []string + for k := range r.Trailer { + decl = append(decl, k) + } + sort.Strings(decl) + + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Server reading request body: %v", err) + } + if string(slurp) != "foo" { + t.Errorf("Server read request body %q; want foo", slurp) + } + if r.Trailer == nil { + io.WriteString(w, "nil Trailer") + } else { + fmt.Fprintf(w, "decl: %v, vals: %s, %s", + decl, + r.Trailer.Get("Client-Trailer-A"), + r.Trailer.Get("Client-Trailer-B")) + } + + // TODO: golang.org/issue/7759: there's no way yet for + // the server to set trailers without hijacking, so do + // that for now, just to test the client. Later, in + // Go 1.4, it should be implicit that any mutations + // to w.Header() after the initial write are the + // trailers to be sent, if and only if they were + // previously declared with w.Header().Set("Trailer", + // ..keys..) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + t := Header{} + t.Set("Server-Trailer-A", "valuea") + t.Set("Server-Trailer-C", "valuec") // skipping B + buf.WriteString("0\r\n") // eof + t.Write(buf) + buf.WriteString("\r\n") // end of trailers + buf.Flush() + conn.Close() + })) + defer ts.Close() + + var req *Request + req, _ = NewRequest("POST", ts.URL, io.MultiReader( + eofReaderFunc(func() { + req.Trailer["Client-Trailer-A"] = []string{"valuea"} + }), + strings.NewReader("foo"), + eofReaderFunc(func() { + req.Trailer["Client-Trailer-B"] = []string{"valueb"} + }), + )) + req.Trailer = Header{ + "Client-Trailer-A": nil, // to be set later + "Client-Trailer-B": nil, // to be set later + } + req.ContentLength = -1 + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { + t.Error(err) + } + want := Header{ + "Server-Trailer-A": []string{"valuea"}, + "Server-Trailer-B": nil, + "Server-Trailer-C": []string{"valuec"}, + } + if !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want) + } +} diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go index 8b01c508e..dc60ba87f 100644 --- a/src/pkg/net/http/cookie.go +++ b/src/pkg/net/http/cookie.go @@ -76,11 +76,7 @@ func readSetCookies(h Header) []*Cookie { attr, val = attr[:j], attr[j+1:] } lowerAttr := strings.ToLower(attr) - parseCookieValueFn := parseCookieValue - if lowerAttr == "expires" { - parseCookieValueFn = parseCookieExpiresValue - } - val, success = parseCookieValueFn(val) + val, success = parseCookieValue(val) if !success { c.Unparsed = append(c.Unparsed, parts[i]) continue @@ -94,7 +90,6 @@ func readSetCookies(h Header) []*Cookie { continue case "domain": c.Domain = val - // TODO: Add domain parsing continue case "max-age": secs, err := strconv.Atoi(val) @@ -121,7 +116,6 @@ func readSetCookies(h Header) []*Cookie { continue case "path": c.Path = val - // TODO: Add path parsing continue } c.Unparsed = append(c.Unparsed, parts[i]) @@ -300,12 +294,23 @@ func sanitizeCookieName(n string) string { // ; US-ASCII characters excluding CTLs, // ; whitespace DQUOTE, comma, semicolon, // ; and backslash +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value in when value starts or ends +// with a comma or space. +// See http://golang.org/issue/7243 for the discussion. func sanitizeCookieValue(v string) string { - return sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' { + return `"` + v + `"` + } + return v } func validCookieValueByte(b byte) bool { - return 0x20 < b && b < 0x7f && b != '"' && b != ',' && b != ';' && b != '\\' + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' } // path-av = "Path=" path-value @@ -340,38 +345,13 @@ func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { return string(buf) } -func unquoteCookieValue(v string) string { - if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { - return v[1 : len(v)-1] - } - return v -} - -func isCookieByte(c byte) bool { - switch { - case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a, - 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e: - return true - } - return false -} - -func isCookieExpiresByte(c byte) (ok bool) { - return isCookieByte(c) || c == ',' || c == ' ' -} - func parseCookieValue(raw string) (string, bool) { - return parseCookieValueUsing(raw, isCookieByte) -} - -func parseCookieExpiresValue(raw string) (string, bool) { - return parseCookieValueUsing(raw, isCookieExpiresByte) -} - -func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) { - raw = unquoteCookieValue(raw) + // Strip the quotes, if present. + if len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } for i := 0; i < len(raw); i++ { - if !validByte(raw[i]) { + if !validCookieValueByte(raw[i]) { return "", false } } diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go index 11b01cc57..f78f37299 100644 --- a/src/pkg/net/http/cookie_test.go +++ b/src/pkg/net/http/cookie_test.go @@ -5,9 +5,13 @@ package http import ( + "bytes" "encoding/json" "fmt" + "log" + "os" "reflect" + "strings" "testing" "time" ) @@ -48,15 +52,61 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"}, "cookie-8=eight", }, + // The "special" cookies have values containing commas or spaces which + // are disallowed by RFC 6265 but are common in the wild. + { + &Cookie{Name: "special-1", Value: "a z"}, + `special-1=a z`, + }, + { + &Cookie{Name: "special-2", Value: " z"}, + `special-2=" z"`, + }, + { + &Cookie{Name: "special-3", Value: "a "}, + `special-3="a "`, + }, + { + &Cookie{Name: "special-4", Value: " "}, + `special-4=" "`, + }, + { + &Cookie{Name: "special-5", Value: "a,z"}, + `special-5=a,z`, + }, + { + &Cookie{Name: "special-6", Value: ",z"}, + `special-6=",z"`, + }, + { + &Cookie{Name: "special-7", Value: "a,"}, + `special-7="a,"`, + }, + { + &Cookie{Name: "special-8", Value: ","}, + `special-8=","`, + }, + { + &Cookie{Name: "empty-value", Value: ""}, + `empty-value=`, + }, } func TestWriteSetCookies(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + for i, tt := range writeSetCookiesTests { if g, e := tt.Cookie.String(), tt.Raw; g != e { t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) continue } } + + if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } type headerOnlyResponseWriter Header @@ -166,6 +216,40 @@ var readSetCookiesTests = []struct { Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly", }}, }, + // Make sure we can properly read back the Set-Cookie headers we create + // for values containing spaces or commas: + { + Header{"Set-Cookie": {`special-1=a z`}}, + []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}}, + }, + { + Header{"Set-Cookie": {`special-2=" z"`}}, + []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}}, + }, + { + Header{"Set-Cookie": {`special-3="a "`}}, + []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}}, + }, + { + Header{"Set-Cookie": {`special-4=" "`}}, + []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}}, + }, + { + Header{"Set-Cookie": {`special-5=a,z`}}, + []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}}, + }, + { + Header{"Set-Cookie": {`special-6=",z"`}}, + []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}}, + }, + { + Header{"Set-Cookie": {`special-7=a,`}}, + []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}}, + }, + { + Header{"Set-Cookie": {`special-8=","`}}, + []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}}, + }, // TODO(bradfitz): users have reported seeing this in the // wild, but do browsers handle it? RFC 6265 just says "don't @@ -244,22 +328,39 @@ func TestReadCookies(t *testing.T) { } func TestCookieSanitizeValue(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + tests := []struct { in, want string }{ {"foo", "foo"}, - {"foo bar", "foobar"}, + {"foo;bar", "foobar"}, + {"foo\\bar", "foobar"}, + {"foo\"bar", "foobar"}, {"\x00\x7e\x7f\x80", "\x7e"}, {`"withquotes"`, "withquotes"}, + {"a z", "a z"}, + {" z", `" z"`}, + {"a ", `"a "`}, } for _, tt := range tests { if got := sanitizeCookieValue(tt.in); got != tt.want { t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want) } } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } func TestCookieSanitizePath(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + tests := []struct { in, want string }{ @@ -272,4 +373,8 @@ func TestCookieSanitizePath(t *testing.T) { t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want) } } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index 22b7f2796..960563b24 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -21,7 +21,7 @@ var ExportAppendTime = appendTime func (t *Transport) NumPendingRequestsForTesting() int { t.reqMu.Lock() defer t.reqMu.Unlock() - return len(t.reqConn) + return len(t.reqCanceler) } func (t *Transport) IdleConnKeysForTesting() (keys []string) { @@ -32,7 +32,7 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { return } for key := range t.idleConn { - keys = append(keys, key) + keys = append(keys, key.String()) } return } @@ -43,11 +43,12 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { if t.idleConn == nil { return 0 } - conns, ok := t.idleConn[cacheKey] - if !ok { - return 0 + for k, conns := range t.idleConn { + if k.String() == cacheKey { + return len(conns) + } } - return len(conns) + return 0 } func (t *Transport) IdleConnChMapSizeForTesting() int { @@ -63,4 +64,9 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { return &timeoutHandler{handler, f, ""} } +func ResetCachedEnvironment() { + httpProxyEnv.reset() + noProxyEnv.reset() +} + var DefaultUserAgent = defaultUserAgent diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go index 60b794e07..a3beaa33a 100644 --- a/src/pkg/net/http/fcgi/child.go +++ b/src/pkg/net/http/fcgi/child.go @@ -16,6 +16,7 @@ import ( "net/http/cgi" "os" "strings" + "sync" "time" ) @@ -126,8 +127,10 @@ func (r *response) Close() error { } type child struct { - conn *conn - handler http.Handler + conn *conn + handler http.Handler + + mu sync.Mutex // protects requests: requests map[uint16]*request // keyed by request ID } @@ -157,7 +160,9 @@ var errCloseConn = errors.New("fcgi: connection should be closed") var emptyBody = ioutil.NopCloser(strings.NewReader("")) func (c *child) handleRecord(rec *record) error { + c.mu.Lock() req, ok := c.requests[rec.h.Id] + c.mu.Unlock() if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { // The spec says to ignore unknown request IDs. return nil @@ -179,7 +184,10 @@ func (c *child) handleRecord(rec *record) error { c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) return nil } - c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + req = newRequest(rec.h.Id, br.flags) + c.mu.Lock() + c.requests[rec.h.Id] = req + c.mu.Unlock() return nil case typeParams: // NOTE(eds): Technically a key-value pair can straddle the boundary @@ -220,7 +228,9 @@ func (c *child) handleRecord(rec *record) error { return nil case typeAbortRequest: println("abort") + c.mu.Lock() delete(c.requests, rec.h.Id) + c.mu.Unlock() c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return @@ -247,6 +257,9 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { c.handler.ServeHTTP(r, httpReq) } r.Close() + c.mu.Lock() + delete(c.requests, req.reqId) + c.mu.Unlock() c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) // Consume the entire body, so the host isn't still writing to diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go index 8b32ca1d0..8576cf844 100644 --- a/src/pkg/net/http/fs.go +++ b/src/pkg/net/http/fs.go @@ -13,6 +13,7 @@ import ( "mime" "mime/multipart" "net/textproto" + "net/url" "os" "path" "path/filepath" @@ -52,12 +53,14 @@ type FileSystem interface { // A File is returned by a FileSystem's Open method and can be // served by the FileServer implementation. +// +// The methods should behave the same as those on an *os.File. type File interface { - Close() error - Stat() (os.FileInfo, error) + io.Closer + io.Reader Readdir(count int) ([]os.FileInfo, error) - Read([]byte) (int, error) Seek(offset int64, whence int) (int64, error) + Stat() (os.FileInfo, error) } func dirList(w ResponseWriter, f File) { @@ -73,8 +76,11 @@ func dirList(w ResponseWriter, f File) { if d.IsDir() { name += "/" } - // TODO htmlescape - fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", name, name) + // name may contain '?' or '#', which must be escaped to remain + // part of the URL path, and not indicate the start of a query + // string or fragment. + url := url.URL{Path: name} + fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name)) } } fmt.Fprintf(w, "</pre>\n") @@ -521,7 +527,7 @@ func (w *countingWriter) Write(p []byte) (n int, err error) { return len(p), nil } -// rangesMIMESize returns the nunber of bytes it takes to encode the +// rangesMIMESize returns the number of bytes it takes to encode the // provided ranges as a multipart response. func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { var w countingWriter diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go index ae54edf0c..f968565f9 100644 --- a/src/pkg/net/http/fs_test.go +++ b/src/pkg/net/http/fs_test.go @@ -227,6 +227,54 @@ func TestFileServerCleans(t *testing.T) { } } +func TestFileServerEscapesNames(t *testing.T) { + defer afterTest(t) + const dirListPrefix = "<pre>\n" + const dirListSuffix = "\n</pre>\n" + tests := []struct { + name, escaped string + }{ + {`simple_name`, `<a href="simple_name">simple_name</a>`}, + {`"'<>&`, `<a href="%22%27%3C%3E&">"'<>&</a>`}, + {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`}, + {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo"><combo>?foo</a>`}, + } + + // We put each test file in its own directory in the fakeFS so we can look at it in isolation. + fs := make(fakeFS) + for i, test := range tests { + testFile := &fakeFileInfo{basename: test.name} + fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{ + dir: true, + modtime: time.Unix(1000000000, 0).UTC(), + ents: []*fakeFileInfo{testFile}, + } + fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + for i, test := range tests { + url := fmt.Sprintf("%s/%d", ts.URL, i) + res, err := Get(url) + if err != nil { + t.Fatalf("test %q: Get: %v", test.name, err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("test %q: read Body: %v", test.name, err) + } + s := string(b) + if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) { + t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix) + } + if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped { + t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped) + } + res.Body.Close() + } +} + func mustRemoveAll(dir string) { err := os.RemoveAll(dir) if err != nil { @@ -457,8 +505,9 @@ func (f *fakeFileInfo) Mode() os.FileMode { type fakeFile struct { io.ReadSeeker - fi *fakeFileInfo - path string // as opened + fi *fakeFileInfo + path string // as opened + entpos int } func (f *fakeFile) Close() error { return nil } @@ -468,10 +517,20 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { return nil, os.ErrInvalid } var fis []os.FileInfo - for _, fi := range f.fi.ents { - fis = append(fis, fi) + + limit := f.entpos + count + if count <= 0 || limit > len(f.fi.ents) { + limit = len(f.fi.ents) + } + for ; f.entpos < limit; f.entpos++ { + fis = append(fis, f.fi.ents[f.entpos]) + } + + if len(fis) == 0 && count > 0 { + return fis, io.EOF + } else { + return fis, nil } - return fis, nil } type fakeFS map[string]*fakeFileInfo @@ -480,7 +539,6 @@ func (fs fakeFS) Open(name string) (File, error) { name = path.Clean(name) f, ok := fs[name] if !ok { - println("fake filesystem didn't find file", name) return nil, os.ErrNotExist } return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go index ca1ae07c2..153b94370 100644 --- a/src/pkg/net/http/header.go +++ b/src/pkg/net/http/header.go @@ -9,9 +9,12 @@ import ( "net/textproto" "sort" "strings" + "sync" "time" ) +var raceEnabled = false // set by race.go + // A Header represents the key-value pairs in an HTTP header. type Header map[string][]string @@ -114,18 +117,15 @@ func (s *headerSorter) Len() int { return len(s.kvs) } func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } -// TODO: convert this to a sync.Cache (issue 4720) -var headerSorterCache = make(chan *headerSorter, 8) +var headerSorterPool = sync.Pool{ + New: func() interface{} { return new(headerSorter) }, +} // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { - select { - case hs = <-headerSorterCache: - default: - hs = new(headerSorter) - } + hs = headerSorterPool.Get().(*headerSorter) if cap(hs.kvs) < len(h) { hs.kvs = make([]keyValues, 0, len(h)) } @@ -159,10 +159,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } } } - select { - case headerSorterCache <- sorter: - default: - } + headerSorterPool.Put(sorter) return nil } diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go index 9fd9837a5..9dcd591fa 100644 --- a/src/pkg/net/http/header_test.go +++ b/src/pkg/net/http/header_test.go @@ -192,9 +192,12 @@ func BenchmarkHeaderWriteSubset(b *testing.B) { } } -func TestHeaderWriteSubsetMallocs(t *testing.T) { +func TestHeaderWriteSubsetAllocs(t *testing.T) { if testing.Short() { - t.Skip("skipping malloc count in short mode") + t.Skip("skipping alloc test in short mode") + } + if raceEnabled { + t.Skip("skipping test under race detector") } if runtime.GOMAXPROCS(0) > 1 { t.Skip("skipping; GOMAXPROCS>1") @@ -204,6 +207,6 @@ func TestHeaderWriteSubsetMallocs(t *testing.T) { testHeader.WriteSubset(&buf, nil) }) if n > 0 { - t.Errorf("mallocs = %g; want 0", n) + t.Errorf("allocs = %g; want 0", n) } } diff --git a/src/pkg/net/http/httptest/server_test.go b/src/pkg/net/http/httptest/server_test.go index 500a9f0b8..501cc8a99 100644 --- a/src/pkg/net/http/httptest/server_test.go +++ b/src/pkg/net/http/httptest/server_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" "testing" + "time" ) func TestServer(t *testing.T) { @@ -27,3 +28,25 @@ func TestServer(t *testing.T) { t.Errorf("got %q, want hello", string(got)) } } + +func TestIssue7264(t *testing.T) { + for i := 0; i < 1000; i++ { + func() { + inHandler := make(chan bool, 1) + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + })) + defer ts.Close() + tr := &http.Transport{ + ResponseHeaderTimeout: time.Nanosecond, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + res, err := c.Get(ts.URL) + <-inHandler + if err == nil { + res.Body.Close() + } + }() + } +} diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go index b66d40951..9632bfd19 100644 --- a/src/pkg/net/http/httputil/chunked.go +++ b/src/pkg/net/http/httputil/chunked.go @@ -4,15 +4,14 @@ // The wire protocol for HTTP's "chunked" Transfer-Encoding. -// This code is a duplicate of ../chunked.go with these edits: -// s/newChunked/NewChunked/g -// s/package http/package httputil/ +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package httputil import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -22,13 +21,13 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") -// NewChunkedReader returns a new chunkedReader that translates the data read from r +// newChunkedReader returns a new chunkedReader that translates the data read from r // out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // -// NewChunkedReader is not needed by normal applications. The http package +// newChunkedReader is not needed by normal applications. The http package // automatically decodes chunking when reading response bodies. -func NewChunkedReader(r io.Reader) io.Reader { +func newChunkedReader(r io.Reader) io.Reader { br, ok := r.(*bufio.Reader) if !ok { br = bufio.NewReader(r) @@ -59,26 +58,45 @@ func (cr *chunkedReader) beginChunk() { } } -func (cr *chunkedReader) Read(b []uint8) (n int, err error) { - if cr.err != nil { - return 0, cr.err +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 } - if cr.n == 0 { - cr.beginChunk() - if cr.err != nil { - return 0, cr.err + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue } - } - if uint64(len(b)) > cr.n { - b = b[0:cr.n] - } - n, cr.err = cr.r.Read(b) - cr.n -= uint64(n) - if cr.n == 0 && cr.err == nil { - // end of chunk (CRLF) - if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } } } } @@ -117,16 +135,16 @@ func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } -// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP // "chunked" format before writing them to w. Closing the returned chunkedWriter // sends the final 0-length chunk that marks the end of the stream. // -// NewChunkedWriter is not needed by normal applications. The http +// newChunkedWriter is not needed by normal applications. The http // package adds chunking automatically if handlers don't set a -// Content-Length header. Using NewChunkedWriter inside a handler +// Content-Length header. Using newChunkedWriter inside a handler // would result in double chunking or chunking with a Content-Length // length, both of which are wrong. -func NewChunkedWriter(w io.Writer) io.WriteCloser { +func newChunkedWriter(w io.Writer) io.WriteCloser { return &chunkedWriter{w} } diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go index a06bffad5..a7a577468 100644 --- a/src/pkg/net/http/httputil/chunked_test.go +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -2,26 +2,25 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This code is a duplicate of ../chunked_test.go with these edits: -// s/newChunked/NewChunked/g -// s/package http/package httputil/ +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package httputil import ( + "bufio" "bytes" "fmt" "io" "io/ioutil" - "runtime" + "strings" "testing" ) func TestChunk(t *testing.T) { var b bytes.Buffer - w := NewChunkedWriter(&b) + w := newChunkedWriter(&b) const chunk1 = "hello, " const chunk2 = "world! 0123456789abcdef" w.Write([]byte(chunk1)) @@ -32,7 +31,7 @@ func TestChunk(t *testing.T) { t.Fatalf("chunk writer wrote %q; want %q", g, e) } - r := NewChunkedReader(&b) + r := newChunkedReader(&b) data, err := ioutil.ReadAll(r) if err != nil { t.Logf(`data: "%s"`, data) @@ -43,37 +42,102 @@ func TestChunk(t *testing.T) { } } +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := newChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := newChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + func TestChunkReaderAllocs(t *testing.T) { - // temporarily set GOMAXPROCS to 1 as we are testing memory allocations - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if testing.Short() { + t.Skip("skipping in short mode") + } var buf bytes.Buffer - w := NewChunkedWriter(&buf) + w := newChunkedWriter(&buf) a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") w.Write(a) w.Write(b) w.Write(c) w.Close() - r := NewChunkedReader(&buf) readBuf := make([]byte, len(a)+len(b)+len(c)+1) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - m0 := ms.Mallocs - - n, err := io.ReadFull(r, readBuf) - - runtime.ReadMemStats(&ms) - mallocs := ms.Mallocs - m0 - if mallocs > 1 { - t.Errorf("%d mallocs; want <= 1", mallocs) - } - - if n != len(readBuf)-1 { - t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) - } - if err != io.ErrUnexpectedEOF { - t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, 0) + bufr.Reset(byter) + r := newChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) } } diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go index 265499fb0..2a7a413d0 100644 --- a/src/pkg/net/http/httputil/dump.go +++ b/src/pkg/net/http/httputil/dump.go @@ -7,6 +7,7 @@ package httputil import ( "bufio" "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -29,7 +30,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { if err = b.Close(); err != nil { return nil, nil, err } - return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil } // dumpConn is a net.Conn which writes to Writer and reads from Reader @@ -106,6 +107,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } + defer t.CloseIdleConnections() _, err := t.RoundTrip(reqSend) @@ -230,14 +232,31 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { return } +// errNoBody is a sentinel error value used by failureToReadBody so we can detect +// that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume the +// body, but need a non-nil one, and want to distinguish the error +// from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + // DumpResponse is like DumpRequest but dumps a response. func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { var b bytes.Buffer save := resp.Body savecl := resp.ContentLength - if !body || resp.Body == nil { - resp.Body = nil - resp.ContentLength = 0 + + if !body { + resp.Body = failureToReadBody{} + } else if resp.Body == nil { + resp.Body = emptyBody } else { save, resp.Body, err = drainBody(resp.Body) if err != nil { @@ -245,11 +264,13 @@ func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { } } err = resp.Write(&b) + if err == errNoBody { + err = nil + } resp.Body = save resp.ContentLength = savecl if err != nil { - return + return nil, err } - dump = b.Bytes() - return + return b.Bytes(), nil } diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go index 987a82048..e1ffb3935 100644 --- a/src/pkg/net/http/httputil/dump_test.go +++ b/src/pkg/net/http/httputil/dump_test.go @@ -11,6 +11,8 @@ import ( "io/ioutil" "net/http" "net/url" + "runtime" + "strings" "testing" ) @@ -112,6 +114,7 @@ var dumpTests = []dumpTest{ } func TestDumpRequest(t *testing.T) { + numg0 := runtime.NumGoroutine() for i, tt := range dumpTests { setBody := func() { if tt.Body == nil { @@ -119,7 +122,7 @@ func TestDumpRequest(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } @@ -155,6 +158,9 @@ func TestDumpRequest(t *testing.T) { } } } + if dg := runtime.NumGoroutine() - numg0; dg > 4 { + t.Errorf("Unexpectedly large number of new goroutines: %d new", dg) + } } func chunk(s string) string { @@ -176,3 +182,82 @@ func mustNewRequest(method, url string, body io.Reader) *http.Request { } return req } + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: ioutil.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: ioutil.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.Replace(got, "\r", "", -1) + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} diff --git a/src/pkg/net/http/httputil/httputil.go b/src/pkg/net/http/httputil/httputil.go new file mode 100644 index 000000000..74fb6c655 --- /dev/null +++ b/src/pkg/net/http/httputil/httputil.go @@ -0,0 +1,32 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import "io" + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return newChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return newChunkedWriter(w) +} diff --git a/src/pkg/net/http/httputil/persist.go b/src/pkg/net/http/httputil/persist.go index 507938aca..987bcc96b 100644 --- a/src/pkg/net/http/httputil/persist.go +++ b/src/pkg/net/http/httputil/persist.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package httputil provides HTTP utility functions, complementing the -// more common ones in the net/http package. package httputil import ( @@ -33,8 +31,8 @@ var errClosed = errors.New("i/o operation on closed connection") // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. // -// ServerConn is low-level and should not be needed by most applications. -// See Server. +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. type ServerConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -47,8 +45,11 @@ type ServerConn struct { pipe textproto.Pipeline } -// NewServerConn returns a new ServerConn reading and writing c. If r is not +// NewServerConn returns a new ServerConn reading and writing c. If r is not // nil, it is the buffer to use when reading c. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { if r == nil { r = bufio.NewReader(c) @@ -223,8 +224,8 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // supports hijacking the connection calling Hijack to // regain control of the underlying net.Conn and deal with it as desired. // -// ClientConn is low-level and should not be needed by most applications. -// See Client. +// ClientConn is low-level and old. Applications should instead use +// Client or Transport in the net/http package. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -240,6 +241,9 @@ type ClientConn struct { // NewClientConn returns a new ClientConn reading and writing c. If r is not // nil, it is the buffer to use when reading c. +// +// ClientConn is low-level and old. Applications should use Client or +// Transport in the net/http package. func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { if r == nil { r = bufio.NewReader(c) @@ -254,6 +258,9 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { // NewProxyClientConn works like NewClientConn but writes Requests // using Request's WriteProxy method. +// +// New code should not use NewProxyClientConn. See Client or +// Transport in the net/http package instead. func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { cc := NewClientConn(c, r) cc.writeReq = (*http.Request).WriteProxy diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 1990f64db..48ada5f5f 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -144,6 +144,10 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } defer res.Body.Close() + for _, h := range hopHeaders { + res.Header.Del(h) + } + copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 1c0444ec4..e9539b44b 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -16,6 +16,12 @@ import ( "time" ) +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + hopHeaders = append(hopHeaders, fakeHopHeader) +} + func TestReverseProxy(t *testing.T) { const backendResponse = "I am the backend" const backendStatus = 404 @@ -36,6 +42,10 @@ func TestReverseProxy(t *testing.T) { t.Errorf("backend got Host header %q, want %q", g, e) } w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) w.WriteHeader(backendStatus) w.Write([]byte(backendResponse)) @@ -64,6 +74,12 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Header.Get("X-Foo"), "bar"; g != e { t.Errorf("got X-Foo %q; expected %q", g, e) } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } if g, e := len(res.Header["Set-Cookie"]), 1; g != e { t.Fatalf("got %d SetCookies, want %d", g, e) } diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go index 449ccaeea..b6aed3792 100644 --- a/src/pkg/net/http/proxy_test.go +++ b/src/pkg/net/http/proxy_test.go @@ -35,12 +35,8 @@ var UseProxyTests = []struct { } func TestUseProxy(t *testing.T) { - oldenv := os.Getenv("NO_PROXY") - defer os.Setenv("NO_PROXY", oldenv) - - no_proxy := "foobar.com, .barbaz.net" - os.Setenv("NO_PROXY", no_proxy) - + ResetProxyEnv() + os.Setenv("NO_PROXY", "foobar.com, .barbaz.net") for _, test := range UseProxyTests { if useProxy(test.host+":80") != test.match { t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) @@ -71,8 +67,15 @@ func TestCacheKeys(t *testing.T) { proxy = u } cm := connectMethod{proxy, tt.scheme, tt.addr} - if cm.String() != tt.key { - t.Fatalf("{%q, %q, %q} cache key %q; want %q", tt.proxy, tt.scheme, tt.addr, cm.String(), tt.key) + if got := cm.key().String(); got != tt.key { + t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) } } } + +func ResetProxyEnv() { + for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} { + os.Setenv(v, "") + } + ResetCachedEnvironment() +} diff --git a/src/pkg/net/http/race.go b/src/pkg/net/http/race.go new file mode 100644 index 000000000..766503967 --- /dev/null +++ b/src/pkg/net/http/race.go @@ -0,0 +1,11 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package http + +func init() { + raceEnabled = true +} diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 57b5d0948..a67092066 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -20,6 +20,7 @@ import ( "net/url" "strconv" "strings" + "sync" ) const ( @@ -68,18 +69,31 @@ var reqWriteExcludeHeader = map[string]bool{ // A Request represents an HTTP request received by a server // or to be sent by a client. +// +// The field semantics differ slightly between client and server +// usage. In addition to the notes on the fields below, see the +// documentation for Request.Write and RoundTripper. type Request struct { - Method string // GET, POST, PUT, etc. + // Method specifies the HTTP method (GET, POST, PUT, etc.). + // For client requests an empty string means GET. + Method string - // URL is created from the URI supplied on the Request-Line - // as stored in RequestURI. + // URL specifies either the URI being requested (for server + // requests) or the URL to access (for client requests). + // + // For server requests the URL is parsed from the URI + // supplied on the Request-Line as stored in RequestURI. For + // most requests, fields other than Path and RawQuery will be + // empty. (See RFC 2616, Section 5.1.2) // - // For most requests, fields other than Path and RawQuery - // will be empty. (See RFC 2616, Section 5.1.2) + // For client requests, the URL's Host specifies the server to + // connect to, while the Request's Host field optionally + // specifies the Host header value to send in the HTTP + // request. URL *url.URL // The protocol version for incoming requests. - // Outgoing requests always use HTTP/1.1. + // Client requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -103,15 +117,20 @@ type Request struct { // The request parser implements this by canonicalizing the // name, making the first character and any characters // following a hyphen uppercase and the rest lowercase. + // + // For client requests certain headers are automatically + // added and may override values in Header. + // + // See the documentation for the Request.Write method. Header Header // Body is the request's body. // - // For client requests, a nil body means the request has no + // For client requests a nil body means the request has no // body, such as a GET request. The HTTP Client's Transport // is responsible for calling the Close method. // - // For server requests, the Request Body is always non-nil + // For server requests the Request Body is always non-nil // but will return EOF immediately when no body is present. // The Server will close the request body. The ServeHTTP // Handler does not need to. @@ -121,7 +140,7 @@ type Request struct { // The value -1 indicates that the length is unknown. // Values >= 0 indicate that the given number of bytes may // be read from Body. - // For outgoing requests, a value of 0 means unknown if Body is not nil. + // For client requests, a value of 0 means unknown if Body is not nil. ContentLength int64 // TransferEncoding lists the transfer encodings from outermost to @@ -132,13 +151,18 @@ type Request struct { TransferEncoding []string // Close indicates whether to close the connection after - // replying to this request. + // replying to this request (for servers) or after sending + // the request (for clients). Close bool - // The host on which the URL is sought. - // Per RFC 2616, this is either the value of the Host: header - // or the host name given in the URL itself. + // For server requests Host specifies the host on which the + // URL is sought. Per RFC 2616, this is either the value of + // the "Host" header or the host name given in the URL itself. // It may be of the form "host:port". + // + // For client requests Host optionally overrides the Host + // header to send. If empty, the Request.Write method uses + // the value of URL.Host. Host string // Form contains the parsed form data, including both the URL @@ -158,12 +182,24 @@ type Request struct { // The HTTP client ignores MultipartForm and uses Body instead. MultipartForm *multipart.Form - // Trailer maps trailer keys to values. Like for Header, if the - // response has multiple trailer lines with the same key, they will be - // concatenated, delimited by commas. - // For server requests, Trailer is only populated after Body has been - // closed or fully consumed. - // Trailer support is only partially complete. + // Trailer specifies additional headers that are sent after the request + // body. + // + // For server requests the Trailer map initially contains only the + // trailer keys, with nil values. (The client declares which trailers it + // will later send.) While the handler is reading from Body, it must + // not reference Trailer. After reading from Body returns EOF, Trailer + // can be read again and will contain non-nil values, if they were sent + // by the client. + // + // For client requests Trailer must be initialized to a map containing + // the trailer keys to later send. The values may be nil or their final + // values. The ContentLength must be 0 or -1, to send a chunked request. + // After the HTTP request is sent the map values can be updated while + // the request body is read. Once the body returns EOF, the caller must + // not mutate Trailer. + // + // Few HTTP clients, servers, or proxies support HTTP trailers. Trailer Header // RemoteAddr allows HTTP servers and other software to record @@ -381,7 +417,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err return err } - // TODO: split long values? (If so, should share code with Conn.Write) err = req.Header.WriteSubset(w, reqWriteExcludeHeader) if err != nil { return err @@ -494,25 +529,20 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { return line[:s1], line[s1+1 : s2], line[s2+1:], true } -// TODO(bradfitz): use a sync.Cache when available -var textprotoReaderCache = make(chan *textproto.Reader, 4) +var textprotoReaderPool sync.Pool func newTextprotoReader(br *bufio.Reader) *textproto.Reader { - select { - case r := <-textprotoReaderCache: - r.R = br - return r - default: - return textproto.NewReader(br) + if v := textprotoReaderPool.Get(); v != nil { + tr := v.(*textproto.Reader) + tr.R = br + return tr } + return textproto.NewReader(br) } func putTextprotoReader(r *textproto.Reader) { r.R = nil - select { - case textprotoReaderCache <- r: - default: - } + textprotoReaderPool.Put(r) } // ReadRequest reads and parses a request from b. @@ -588,32 +618,6 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { fixPragmaCacheControl(req.Header) - // TODO: Parse specific header values: - // Accept - // Accept-Encoding - // Accept-Language - // Authorization - // Cache-Control - // Connection - // Date - // Expect - // From - // If-Match - // If-Modified-Since - // If-None-Match - // If-Range - // If-Unmodified-Since - // Max-Forwards - // Proxy-Authorization - // Referer [sic] - // TE (transfer-codings) - // Trailer - // Transfer-Encoding - // Upgrade - // User-Agent - // Via - // Warning - err = readTransfer(req, b) if err != nil { return nil, err @@ -677,6 +681,11 @@ func parsePostForm(r *Request) (vs url.Values, err error) { return } ct := r.Header.Get("Content-Type") + // RFC 2616, section 7.2.1 - empty type + // SHOULD be treated as application/octet-stream + if ct == "" { + ct = "application/octet-stream" + } ct, _, err = mime.ParseMediaType(ct) switch { case ct == "application/x-www-form-urlencoded": @@ -707,7 +716,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { // orders to call too many functions here. // Clean this up and write more tests. // request_test.go contains the start of this, - // in TestRequestMultipartCallOrder. + // in TestParseMultipartFormOrder and others. } return } @@ -727,7 +736,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { func (r *Request) ParseForm() error { var err error if r.PostForm == nil { - if r.Method == "POST" || r.Method == "PUT" { + if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { r.PostForm, err = parsePostForm(r) } if r.PostForm == nil { @@ -780,9 +789,7 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } mr, err := r.multipartReader() - if err == ErrNotMultipart { - return nil - } else if err != nil { + if err != nil { return err } @@ -860,3 +867,9 @@ func (r *Request) wantsHttp10KeepAlive() bool { func (r *Request) wantsClose() bool { return hasToken(r.Header.get("Connection"), "close") } + +func (r *Request) closeBody() { + if r.Body != nil { + r.Body.Close() + } +} diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index 89303c336..b9fa3c2bf 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -60,6 +60,37 @@ func TestPostQuery(t *testing.T) { } } +func TestPatchQuery(t *testing.T) { + req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } + if z := req.FormValue("z"); z != "post" { + t.Errorf(`req.FormValue("z") = %q, want "post"`, z) + } + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) + } +} + type stringMap map[string][]string type parseContentTypeTest struct { shouldError bool @@ -68,8 +99,9 @@ type parseContentTypeTest struct { var parseContentTypeTests = []parseContentTypeTest{ {false, stringMap{"Content-Type": {"text/plain"}}}, - // Non-existent keys are not placed. The value nil is illegal. - {true, stringMap{}}, + // Empty content type is legal - shoult be treated as + // application/octet-stream (RFC 2616, section 7.2.1) + {false, stringMap{}}, {true, stringMap{"Content-Type": {"text/plain; boundary="}}}, {false, stringMap{"Content-Type": {"application/unknown"}}}, } @@ -79,7 +111,7 @@ func TestParseFormUnknownContentType(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: ioutil.NopCloser(bytes.NewBufferString("body")), + Body: ioutil.NopCloser(strings.NewReader("body")), } err := req.ParseForm() switch { @@ -122,7 +154,25 @@ func TestMultipartReader(t *testing.T) { req.Header = Header{"Content-Type": {"text/plain"}} multipart, err = req.MultipartReader() if multipart != nil { - t.Errorf("unexpected multipart for text/plain") + t.Error("unexpected multipart for text/plain") + } +} + +func TestParseMultipartForm(t *testing.T) { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + err := req.ParseMultipartForm(25) + if err == nil { + t.Error("expected multipart EOF, got nil") + } + + req.Header = Header{"Content-Type": {"text/plain"}} + err = req.ParseMultipartForm(25) + if err != ErrNotMultipart { + t.Error("expected ErrNotMultipart for text/plain") } } @@ -188,25 +238,72 @@ func TestMultipartRequestAuto(t *testing.T) { validateTestMultipartContents(t, req, true) } -func TestEmptyMultipartRequest(t *testing.T) { - // Test that FormValue and FormFile automatically invoke - // ParseMultipartForm and return the right values. - req, err := NewRequest("GET", "/", nil) - if err != nil { - t.Errorf("NewRequest err = %q", err) - } +func TestMissingFileMultipartRequest(t *testing.T) { + // Test that FormFile returns an error if + // the named file is missing. + req := newTestMultipartRequest(t) testMissingFile(t, req) } -func TestRequestMultipartCallOrder(t *testing.T) { +// Test that FormValue invokes ParseMultipartForm. +func TestFormValueCallsParseMultipartForm(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormValue("z") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormValue") + } +} + +// Test that FormFile invokes ParseMultipartForm. +func TestFormFileCallsParseMultipartForm(t *testing.T) { req := newTestMultipartRequest(t) - _, err := req.MultipartReader() - if err != nil { + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormFile("") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormFile") + } +} + +// Test that ParseMultipartForm errors if called +// after MultipartReader on the same request. +func TestParseMultipartFormOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { t.Fatalf("MultipartReader: %v", err) } - err = req.ParseMultipartForm(1024) - if err == nil { - t.Errorf("expected an error from ParseMultipartForm after call to MultipartReader") + if err := req.ParseMultipartForm(1024); err == nil { + t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader") + } +} + +// Test that MultipartReader errors if called +// after ParseMultipartForm on the same request. +func TestMultipartReaderOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatalf("ParseMultipartForm: %v", err) + } + defer req.MultipartForm.RemoveAll() + if _, err := req.MultipartReader(); err == nil { + t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm") + } +} + +// Test that FormFile errors if called after +// MultipartReader on the same request. +func TestFormFileOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { + t.Fatalf("MultipartReader: %v", err) + } + if _, _, err := req.FormFile(""); err == nil { + t.Fatal("expected an error from FormFile after call to MultipartReader") } } @@ -343,7 +440,7 @@ func testMissingFile(t *testing.T, req *Request) { } func newTestMultipartRequest(t *testing.T) *Request { - b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + b := strings.NewReader(strings.Replace(message, "\n", "\r\n", -1)) req, err := NewRequest("POST", "/", b) if err != nil { t.Fatal("NewRequest:", err) diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index b27b1f7ce..dc0e204ca 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -310,6 +310,46 @@ var reqWriteTests = []reqWriteTest{ WantError: errors.New("http: Request.ContentLength=5 with nil Body"), }, + // Request with a 0 ContentLength and a body with 1 byte content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) + }, + + WantError: errors.New("Custom reader error"), + }, + + // Request with a 0 ContentLength and a body without content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(errReader) + }, + + WantError: errors.New("Custom reader error"), + }, + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, // and doesn't add a User-Agent. { @@ -427,7 +467,7 @@ func TestRequestWrite(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go index 35d0ba3bb..5d2c39080 100644 --- a/src/pkg/net/http/response.go +++ b/src/pkg/net/http/response.go @@ -8,6 +8,8 @@ package http import ( "bufio" + "bytes" + "crypto/tls" "errors" "io" "net/textproto" @@ -45,7 +47,8 @@ type Response struct { // // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with - // a zero-lengthed body. + // a zero-length body. It is the caller's responsibility to + // close Body. // // The Body is automatically dechunked if the server replied // with a "chunked" Transfer-Encoding. @@ -74,6 +77,12 @@ type Response struct { // Request's Body is nil (having already been consumed). // This is only populated for Client requests. Request *Request + + // TLS contains information about the TLS connection on which the + // response was received. It is nil for unencrypted responses. + // The pointer is shared between responses and should not be + // modified. + TLS *tls.ConnectionState } // Cookies parses and returns the cookies set in the Set-Cookie headers. @@ -141,6 +150,9 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { // Parse the response headers. mimeHeader, err := tp.ReadMIMEHeader() if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return nil, err } resp.Header = Header(mimeHeader) @@ -187,8 +199,8 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // ContentLength // Header, values for non-canonical keys will have unpredictable behavior // +// Body is closed after it is sent. func (r *Response) Write(w io.Writer) error { - // Status line text := r.Status if text == "" { @@ -201,10 +213,45 @@ func (r *Response) Write(w io.Writer) error { protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " text = strings.TrimPrefix(text, statusCode) - io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") + if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil { + return err + } + + // Clone it, so we can modify r1 as needed. + r1 := new(Response) + *r1 = *r + if r1.ContentLength == 0 && r1.Body != nil { + // Is it actually 0 length? Or just unknown? + var buf [1]byte + n, err := r1.Body.Read(buf[:]) + if err != nil && err != io.EOF { + return err + } + if n == 0 { + // Reset it to a known zero reader, in case underlying one + // is unhappy being read repeatedly. + r1.Body = eofReader + } else { + r1.ContentLength = -1 + r1.Body = struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(buf[:1]), r.Body), + r.Body, + } + } + } + // If we're sending a non-chunked HTTP/1.1 response without a + // content-length, the only way to do that is the old HTTP/1.0 + // way, by noting the EOF with a connection close, so we need + // to set Close. + if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) { + r1.Close = true + } // Process Body,ContentLength,Close,Trailer - tw, err := newTransferWriter(r) + tw, err := newTransferWriter(r1) if err != nil { return err } @@ -219,8 +266,19 @@ func (r *Response) Write(w io.Writer) error { return err } + // contentLengthAlreadySent may have been already sent for + // POST/PUT requests, even if zero length. See Issue 8180. + contentLengthAlreadySent := tw.shouldSendContentLength() + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent { + if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { + return err + } + } + // End-of-header - io.WriteString(w, "\r\n") + if _, err := io.WriteString(w, "\r\n"); err != nil { + return err + } // Write body and trailer err = tw.WriteBody(w) diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 5044306a8..4b8946f7a 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "net/url" "reflect" + "regexp" "strings" "testing" ) @@ -28,6 +29,10 @@ func dummyReq(method string) *Request { return &Request{Method: method} } +func dummyReq11(method string) *Request { + return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1} +} + var respTests = []respTest{ // Unchunked response without Content-Length. { @@ -406,8 +411,7 @@ func TestWriteResponse(t *testing.T) { t.Errorf("#%d: %v", i, err) continue } - bout := bytes.NewBuffer(nil) - err = resp.Write(bout) + err = resp.Write(ioutil.Discard) if err != nil { t.Errorf("#%d: %v", i, err) continue @@ -506,6 +510,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) { rest, err := ioutil.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { + g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { + return fmt.Sprintf("x(repeated x%d)", len(match)) + }) fatalf("remainder = %q, expected %q", g, e) } } @@ -615,6 +622,15 @@ func TestResponseContentLengthShortBody(t *testing.T) { } } +func TestReadResponseUnexpectedEOF(t *testing.T) { + br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" + + "Location: http://example.com")) + _, err := ReadResponse(br, nil) + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err) + } +} + func TestNeedsSniff(t *testing.T) { // needsSniff returns true with an empty response. r := &response{} diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go index 5c10e2161..585b13b85 100644 --- a/src/pkg/net/http/responsewrite_test.go +++ b/src/pkg/net/http/responsewrite_test.go @@ -7,6 +7,7 @@ package http import ( "bytes" "io/ioutil" + "strings" "testing" ) @@ -25,7 +26,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, }, @@ -41,13 +42,113 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, }, "HTTP/1.0 200 OK\r\n" + "\r\n" + "abcdef", }, + // HTTP/1.1 response with unknown length and Connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: true, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close, but + // setting chunked. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, + // HTTP/1.1 response 0 content-length, and nil body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: nil, + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil non-empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("foo")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\nfoo", + }, // HTTP/1.1, chunked coding; empty trailer; close { Response{ @@ -56,7 +157,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, TransferEncoding: []string{"chunked"}, Close: true, @@ -90,6 +191,22 @@ func TestResponseWrite(t *testing.T) { "Foo: Bar Baz\r\n" + "\r\n", }, + + // Want a single Content-Length header. Fixing issue 8180 where + // there were two. + { + Response{ + StatusCode: StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + Request: &Request{Method: "POST"}, + Header: Header{}, + ContentLength: 0, + TransferEncoding: nil, + Body: nil, + }, + "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + }, } for i := range respWriteTests { diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 955112bc2..9e4d226bf 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -419,7 +419,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) { func TestMuxRedirectLeadingSlashes(t *testing.T) { paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} for _, path := range paths { - req, err := ReadRequest(bufio.NewReader(bytes.NewBufferString("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) + req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) if err != nil { t.Errorf("%s", err) } @@ -441,6 +441,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) reqNum := 0 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { @@ -517,6 +520,9 @@ func TestServerTimeouts(t *testing.T) { // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. func TestOnlyWriteTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) var conn net.Conn var afterTimeoutErrc = make(chan error, 1) @@ -840,9 +846,14 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) ts.StartTLS() defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -857,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) { t.Errorf("Read = %d, %v; want an error and no bytes", n, err) } }) + select { + case v := <-errc: + if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + t.Errorf("expected a TLS handshake timeout error; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } } func TestTLSServer(t *testing.T) { @@ -869,6 +888,7 @@ func TestTLSServer(t *testing.T) { } } })) + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) defer ts.Close() // Connect an idle TCP connection to this server before we run @@ -913,31 +933,50 @@ func TestTLSServer(t *testing.T) { } type serverExpectTest struct { - contentLength int // of request body + contentLength int // of request body + chunked bool expectation string // e.g. "100-continue" readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) expectedResponse string // expected substring in first line of http response } +func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest { + return serverExpectTest{ + contentLength: contentLength, + expectation: expectation, + readBody: readBody, + expectedResponse: expectedResponse, + } +} + var serverExpectTests = []serverExpectTest{ // Normal 100-continues, case-insensitive. - {100, "100-continue", true, "100 Continue"}, - {100, "100-cOntInUE", true, "100 Continue"}, + expectTest(100, "100-continue", true, "100 Continue"), + expectTest(100, "100-cOntInUE", true, "100 Continue"), // No 100-continue. - {100, "", true, "200 OK"}, + expectTest(100, "", true, "200 OK"), // 100-continue but requesting client to deny us, // so it never reads the body. - {100, "100-continue", false, "401 Unauthorized"}, + expectTest(100, "100-continue", false, "401 Unauthorized"), // Likewise without 100-continue: - {100, "", false, "401 Unauthorized"}, + expectTest(100, "", false, "401 Unauthorized"), // Non-standard expectations are failures - {0, "a-pony", false, "417 Expectation Failed"}, + expectTest(0, "a-pony", false, "417 Expectation Failed"), - // Expect-100 requested but no body - {0, "100-continue", true, "400 Bad Request"}, + // Expect-100 requested but no body (is apparently okay: Issue 7625) + expectTest(0, "100-continue", true, "200 OK"), + // Expect-100 requested but handler doesn't read the body + expectTest(0, "100-continue", false, "401 Unauthorized"), + // Expect-100 continue with no body, but a chunked body. + { + expectation: "100-continue", + readBody: true, + chunked: true, + expectedResponse: "100 Continue", + }, } // Tests that the server responds to the "Expect" request header @@ -966,21 +1005,38 @@ func TestServerExpect(t *testing.T) { // Only send the body immediately if we're acting like an HTTP client // that doesn't send 100-continue expectations. - writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" + writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue" go func() { + contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength) + if test.chunked { + contentLen = "Transfer-Encoding: chunked" + } _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ "Connection: close\r\n"+ - "Content-Length: %d\r\n"+ + "%s\r\n"+ "Expect: %s\r\nHost: foo\r\n\r\n", - test.readBody, test.contentLength, test.expectation) + test.readBody, contentLen, test.expectation) if err != nil { t.Errorf("On test %#v, error writing request headers: %v", test, err) return } if writeBody { + var targ io.WriteCloser = struct { + io.Writer + io.Closer + }{ + conn, + ioutil.NopCloser(nil), + } + if test.chunked { + targ = httputil.NewChunkedWriter(conn) + } body := strings.Repeat("A", test.contentLength) - _, err = fmt.Fprint(conn, body) + _, err = fmt.Fprint(targ, body) + if err == nil { + err = targ.Close() + } if err != nil { if !test.readBody { // Server likely already hung up on us. @@ -1414,6 +1470,9 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() @@ -1934,6 +1993,31 @@ func TestWriteAfterHijack(t *testing.T) { } } +func TestDoubleHijack(t *testing.T) { + req := reqBytes("GET / HTTP/1.1\nHost: golang.org") + var buf bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + conn, _, err := rw.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + _, _, err = rw.(Hijacker).Hijack() + if err == nil { + t.Errorf("got err = nil; want err != nil") + } + conn.Close() + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec +} + // http://code.google.com/p/go/issues/detail?id=5955 // Note that this does not test the "request too large" // exit path from the http server. This is intentional; @@ -2037,31 +2121,160 @@ func TestServerReaderFromOrder(t *testing.T) { } } -// Issue 6157 -func TestNoContentTypeOnNotModified(t *testing.T) { +// Issue 6157, Issue 6685 +func TestCodesPreventingContentTypeAndBody(t *testing.T) { + for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/header" { + w.Header().Set("Content-Length", "123") + } + w.WriteHeader(code) + if r.URL.Path == "/more" { + w.Write([]byte("stuff")) + } + })) + for _, req := range []string{ + "GET / HTTP/1.0", + "GET /header HTTP/1.0", + "GET /more HTTP/1.0", + "GET / HTTP/1.1", + "GET /header HTTP/1.1", + "GET /more HTTP/1.1", + } { + got := ht.rawResponse(req) + wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) + if !strings.Contains(got, wantStatus) { + t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got) + } else if strings.Contains(got, "Content-Length") { + t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got) + } else if strings.Contains(got, "stuff") { + t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got) + } + } + } +} + +func TestContentTypeOkayOn204(t *testing.T) { ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { - if r.URL.Path == "/header" { - w.Header().Set("Content-Length", "123") + w.Header().Set("Content-Length", "123") // suppressed + w.Header().Set("Content-Type", "foo/bar") + w.WriteHeader(204) + })) + got := ht.rawResponse("GET / HTTP/1.1") + if !strings.Contains(got, "Content-Type: foo/bar") { + t.Errorf("Response = %q; want Content-Type: foo/bar", got) + } + if strings.Contains(got, "Content-Length: 123") { + t.Errorf("Response = %q; don't want a Content-Length", got) + } +} + +// Issue 6995 +// A server Handler can receive a Request, and then turn around and +// give a copy of that Request.Body out to the Transport (e.g. any +// proxy). So then two people own that Request.Body (both the server +// and the http client), and both think they can close it on failure. +// Therefore, all incoming server requests Bodies need to be thread-safe. +func TestTransportAndServerSharedBodyRace(t *testing.T) { + defer afterTest(t) + + const bodySize = 1 << 20 + + unblockBackend := make(chan bool) + backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + io.CopyN(rw, req.Body, bodySize/2) + <-unblockBackend + })) + defer backend.Close() + + backendRespc := make(chan *Response, 1) + proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.RequestURI == "/foo" { + rw.Write([]byte("bar")) + return } - w.WriteHeader(StatusNotModified) - if r.URL.Path == "/more" { - w.Write([]byte("stuff")) + req2, _ := NewRequest("POST", backend.URL, req.Body) + req2.ContentLength = bodySize + + bresp, err := DefaultClient.Do(req2) + if err != nil { + t.Errorf("Proxy outbound request: %v", err) + return + } + _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4) + if err != nil { + t.Errorf("Proxy copy error: %v", err) + return } + backendRespc <- bresp // to close later + + // Try to cause a race: Both the DefaultTransport and the proxy handler's Server + // will try to read/close req.Body (aka req2.Body) + DefaultTransport.(*Transport).CancelRequest(req2) + rw.Write([]byte("OK")) + })) + defer proxy.Close() + + req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatalf("Original request: %v", err) + } + + // Cleanup, so we don't leak goroutines. + res.Body.Close() + close(unblockBackend) + (<-backendRespc).Body.Close() +} + +// Test that a hanging Request.Body.Read from another goroutine can't +// cause the Handler goroutine's Request.Body.Close to block. +func TestRequestBodyCloseDoesntBlock(t *testing.T) { + t.Skipf("Skipping known issue; see golang.org/issue/7121") + if testing.Short() { + t.Skip("skipping in -short mode") + } + defer afterTest(t) + + readErrCh := make(chan error, 1) + errCh := make(chan error, 2) + + server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + go func(body io.Reader) { + _, err := body.Read(make([]byte, 100)) + readErrCh <- err + }(req.Body) + time.Sleep(500 * time.Millisecond) })) - for _, req := range []string{ - "GET / HTTP/1.0", - "GET /header HTTP/1.0", - "GET /more HTTP/1.0", - "GET / HTTP/1.1", - "GET /header HTTP/1.1", - "GET /more HTTP/1.1", - } { - got := ht.rawResponse(req) - if !strings.Contains(got, "304 Not Modified") { - t.Errorf("Non-304 Not Modified for %q: %s", req, got) - } else if strings.Contains(got, "Content-Length") { - t.Errorf("Got a Content-Length from %q: %s", req, got) + defer server.Close() + + closeConn := make(chan bool) + defer close(closeConn) + go func() { + conn, err := net.Dial("tcp", server.Listener.Addr().String()) + if err != nil { + errCh <- err + return + } + defer conn.Close() + _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n")) + if err != nil { + errCh <- err + return } + // And now just block, making the server block on our + // 100000 bytes of body that will never arrive. + <-closeConn + }() + select { + case err := <-readErrCh: + if err == nil { + t.Error("Read was nil. Expected error.") + } + case err := <-errCh: + t.Error(err) + case <-time.After(5 * time.Second): + t.Error("timeout") } } @@ -2073,8 +2286,8 @@ func TestResponseWriterWriteStringAllocs(t *testing.T) { w.Write([]byte("Hello world")) } })) - before := testing.AllocsPerRun(25, func() { ht.rawResponse("GET / HTTP/1.0") }) - after := testing.AllocsPerRun(25, func() { ht.rawResponse("GET /s HTTP/1.0") }) + before := testing.AllocsPerRun(50, func() { ht.rawResponse("GET / HTTP/1.0") }) + after := testing.AllocsPerRun(50, func() { ht.rawResponse("GET /s HTTP/1.0") }) if int(after) >= int(before) { t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before) } @@ -2093,6 +2306,230 @@ func TestAppendTime(t *testing.T) { } } +func TestServerConnState(t *testing.T) { + defer afterTest(t) + handler := map[string]func(w ResponseWriter, r *Request){ + "/": func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello.") + }, + "/close": func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + fmt.Fprintf(w, "Hello.") + }, + "/hijack": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + }, + "/hijack-panic": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + panic("intentional panic") + }, + } + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + handler[r.URL.Path](w, r) + })) + defer ts.Close() + + var mu sync.Mutex // guard stateLog and connID + var stateLog = map[int][]ConnState{} + var connID = map[net.Conn]int{} + + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + mu.Lock() + defer mu.Unlock() + id, ok := connID[c] + if !ok { + id = len(connID) + 1 + connID[c] = id + } + stateLog[id] = append(stateLog[id], state) + } + ts.Start() + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/close") + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/", "Connection", "close") + + mustGet(t, ts.URL+"/hijack") + mustGet(t, ts.URL+"/hijack-panic") + + // New->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Active->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Idle->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil { + t.Fatal(err) + } + res, err := ReadResponse(bufio.NewReader(c), nil) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + c.Close() + } + + want := map[int][]ConnState{ + 1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 3: []ConnState{StateNew, StateActive, StateHijacked}, + 4: []ConnState{StateNew, StateActive, StateHijacked}, + 5: []ConnState{StateNew, StateClosed}, + 6: []ConnState{StateNew, StateActive, StateClosed}, + 7: []ConnState{StateNew, StateActive, StateIdle, StateClosed}, + } + logString := func(m map[int][]ConnState) string { + var b bytes.Buffer + for id, l := range m { + fmt.Fprintf(&b, "Conn %d: ", id) + for _, s := range l { + fmt.Fprintf(&b, "%s ", s) + } + b.WriteString("\n") + } + return b.String() + } + + for i := 0; i < 5; i++ { + time.Sleep(time.Duration(i) * 50 * time.Millisecond) + mu.Lock() + match := reflect.DeepEqual(stateLog, want) + mu.Unlock() + if match { + return + } + } + + mu.Lock() + t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want)) + mu.Unlock() +} + +func mustGet(t *testing.T, url string, headers ...string) { + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + for len(headers) > 0 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Errorf("Error fetching %s: %v", url, err) + return + } + _, err = ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Errorf("Error reading %s: %v", url, err) + } +} + +func TestServerKeepAlivesEnabled(t *testing.T) { + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.SetKeepAlivesEnabled(false) + ts.Start() + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if !res.Close { + t.Errorf("Body.Close == false; want true") + } +} + +// golang.org/issue/7856 +func TestServerEmptyBodyRace(t *testing.T) { + defer afterTest(t) + var n int32 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + atomic.AddInt32(&n, 1) + })) + defer ts.Close() + var wg sync.WaitGroup + const reqs = 20 + for i := 0; i < reqs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + t.Error(err) + return + } + }() + } + wg.Wait() + if got := atomic.LoadInt32(&n); got != reqs { + t.Errorf("handler ran %d times; want %d", got, reqs) + } +} + +func TestServerConnStateNew(t *testing.T) { + sawNew := false // if the test is buggy, we'll race on this variable. + srv := &Server{ + ConnState: func(c net.Conn, state ConnState) { + if state == StateNew { + sawNew = true // testing that this write isn't racy + } + }, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant + } + srv.Serve(&oneConnListener{ + conn: &rwTestConn{ + Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"), + Writer: ioutil.Discard, + }, + }) + if !sawNew { // testing that this read isn't racy + t.Error("StateNew not seen") + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() @@ -2108,6 +2545,7 @@ func BenchmarkClientServer(b *testing.B) { b.Fatal("Get:", err) } all, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { b.Fatal("ReadAll:", err) } @@ -2128,41 +2566,33 @@ func BenchmarkClientServerParallel64(b *testing.B) { benchmarkClientServerParallel(b, 64) } -func benchmarkClientServerParallel(b *testing.B, conc int) { +func benchmarkClientServerParallel(b *testing.B, parallelism int) { b.ReportAllocs() - b.StopTimer() ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") })) defer ts.Close() - b.StartTimer() - - numProcs := runtime.GOMAXPROCS(-1) * conc - var wg sync.WaitGroup - wg.Add(numProcs) - n := int32(b.N) - for p := 0; p < numProcs; p++ { - go func() { - for atomic.AddInt32(&n, -1) >= 0 { - res, err := Get(ts.URL) - if err != nil { - b.Logf("Get: %v", err) - continue - } - all, err := ioutil.ReadAll(res.Body) - if err != nil { - b.Logf("ReadAll: %v", err) - continue - } - body := string(all) - if body != "Hello world.\n" { - panic("Got body: " + body) - } + b.ResetTimer() + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + res, err := Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue } - wg.Done() - }() - } - wg.Wait() + all, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + }) } // A benchmark for profiling the server without the HTTP client code. @@ -2187,6 +2617,7 @@ func BenchmarkServer(b *testing.B) { log.Panicf("Get: %v", err) } all, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { log.Panicf("ReadAll: %v", err) } @@ -2390,3 +2821,28 @@ Host: golang.org b.Errorf("b.N=%d but handled %d", b.N, handled) } } + +func BenchmarkServerHijack(b *testing.B) { + b.ReportAllocs() + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +`) + h := HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + panic(err) + } + conn.Close() + }) + conn := &rwTestConn{ + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + for i := 0; i < b.N; i++ { + conn.Reader = bytes.NewReader(req) + ln.conn = conn + Serve(ln, h) + <-conn.closec + } +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 0e46863d5..eae097eb8 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -22,6 +22,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" ) @@ -138,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { buf = c.buf c.rwc = nil c.buf = nil + c.setState(rwc, StateHijacked) return } @@ -435,56 +437,52 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { return c, nil } -// TODO: use a sync.Cache instead var ( - bufioReaderCache = make(chan *bufio.Reader, 4) - bufioWriterCache2k = make(chan *bufio.Writer, 4) - bufioWriterCache4k = make(chan *bufio.Writer, 4) + bufioReaderPool sync.Pool + bufioWriter2kPool sync.Pool + bufioWriter4kPool sync.Pool ) -func bufioWriterCache(size int) chan *bufio.Writer { +func bufioWriterPool(size int) *sync.Pool { switch size { case 2 << 10: - return bufioWriterCache2k + return &bufioWriter2kPool case 4 << 10: - return bufioWriterCache4k + return &bufioWriter4kPool } return nil } func newBufioReader(r io.Reader) *bufio.Reader { - select { - case p := <-bufioReaderCache: - p.Reset(r) - return p - default: - return bufio.NewReader(r) + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br } + return bufio.NewReader(r) } func putBufioReader(br *bufio.Reader) { br.Reset(nil) - select { - case bufioReaderCache <- br: - default: - } + bufioReaderPool.Put(br) } func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { - select { - case p := <-bufioWriterCache(size): - p.Reset(w) - return p - default: - return bufio.NewWriterSize(w, size) + pool := bufioWriterPool(size) + if pool != nil { + if v := pool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } } + return bufio.NewWriterSize(w, size) } func putBufioWriter(bw *bufio.Writer) { bw.Reset(nil) - select { - case bufioWriterCache(bw.Available()) <- bw: - default: + if pool := bufioWriterPool(bw.Available()); pool != nil { + pool.Put(bw) } } @@ -500,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } +func (srv *Server) initialLimitedReaderSize() int64 { + return int64(srv.maxHeaderBytes()) + 4096 // bufio slop +} + // wrapper around io.ReaderCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -570,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ + c.lr.N = c.server.initialLimitedReaderSize() var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { if c.lr.N == 0 { @@ -618,11 +620,11 @@ const maxPostHandlerReadBytes = 256 << 10 func (w *response) WriteHeader(code int) { if w.conn.hijacked() { - log.Print("http: response.WriteHeader on hijacked connection") + w.conn.server.logf("http: response.WriteHeader on hijacked connection") return } if w.wroteHeader { - log.Print("http: multiple response.WriteHeader calls") + w.conn.server.logf("http: multiple response.WriteHeader calls") return } w.wroteHeader = true @@ -637,7 +639,7 @@ func (w *response) WriteHeader(code int) { if err == nil && v >= 0 { w.contentLength = v } else { - log.Printf("http: invalid Content-Length of %q", cl) + w.conn.server.logf("http: invalid Content-Length of %q", cl) w.handlerHeader.Del("Content-Length") } } @@ -707,6 +709,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res + keepAlivesEnabled := w.conn.server.doKeepAlives() isHEAD := w.req.Method == "HEAD" // header is written out to w.conn.buf below. Depending on the @@ -739,7 +742,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - // Exceptions: 304 responses never get Content-Length, and if + // Exceptions: 304/204/1xx responses never get Content-Length, and if // it was a HEAD request, we don't know the difference between // 0 actual bytes and 0 bytes because the handler noticed it // was a HEAD request and chose not to write anything. So for @@ -747,14 +750,14 @@ func (cw *chunkWriter) writeHeader(p []byte) { // write non-zero bytes. If it's actually 0 bytes and the // handler never looked at the Request.Method, we just don't // send a Content-Length header. - if w.handlerDone && w.status != StatusNotModified && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + if w.handlerDone && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... - if w.req.wantsHttp10KeepAlive() { + if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled { sentLength := header.get("Content-Length") != "" if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false @@ -773,7 +776,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { w.closeAfterReply = true } - if header.get("Connection") == "close" { + if header.get("Connection") == "close" || !keepAlivesEnabled { w.closeAfterReply = true } @@ -796,18 +799,16 @@ func (cw *chunkWriter) writeHeader(p []byte) { } code := w.status - if code == StatusNotModified { - // Must not have body. - // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" - for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - delHeader(k) - } - } else { + if bodyAllowedForStatus(code) { // If no content type, apply sniffing algorithm to body. _, haveType := header["Content-Type"] if !haveType { setHeader.contentType = DetectContentType(p) } + } else { + for _, k := range suppressedHeaders(code) { + delHeader(k) + } } if _, ok := header["Date"]; !ok { @@ -819,13 +820,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. - log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, w.contentLength) delHeader("Content-Length") hasCL = false } - if w.req.Method == "HEAD" || code == StatusNotModified { + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { // do nothing } else if code == StatusNoContent { delHeader("Transfer-Encoding") @@ -855,7 +856,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { return } - if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { + if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) { delHeader("Connection") if w.req.ProtoAtLeast(1, 1) { setHeader.connection = "close" @@ -919,7 +920,7 @@ func (w *response) bodyAllowed() bool { if !w.wroteHeader { panic("") } - return w.status != StatusNotModified + return bodyAllowedForStatus(w.status) } // The Life Of A Write is like this: @@ -965,7 +966,7 @@ func (w *response) WriteString(data string) (n int, err error) { // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { - log.Print("http: response.Write on hijacked connection") + w.conn.server.logf("http: response.Write on hijacked connection") return 0, ErrHijacked } if !w.wroteHeader { @@ -1001,11 +1002,10 @@ func (w *response) finishRequest() { w.cw.close() w.conn.buf.Flush() - // Close the body, unless we're about to close the whole TCP connection - // anyway. - if !w.closeAfterReply { - w.req.Body.Close() - } + // Close the body (regardless of w.closeAfterReply) so we can + // re-use its bufio.Reader later safely. + w.req.Body.Close() + if w.req.MultipartForm != nil { w.req.MultipartForm.RemoveAll() } @@ -1084,17 +1084,25 @@ func validNPN(proto string) bool { return true } +func (c *conn) setState(nc net.Conn, state ConnState) { + if hook := c.server.ConnState; hook != nil { + hook(nc, state) + } +} + // Serve a new connection. func (c *conn) serve() { + origConn := c.rwc // copy it before it's set nil on Close or Hijack defer func() { if err := recover(); err != nil { - const size = 4096 + const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) + c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } if !c.hijacked() { c.close() + c.setState(origConn, StateClosed) } }() @@ -1106,6 +1114,7 @@ func (c *conn) serve() { c.rwc.SetWriteDeadline(time.Now().Add(d)) } if err := tlsConn.Handshake(); err != nil { + c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } c.tlsState = new(tls.ConnectionState) @@ -1121,6 +1130,10 @@ func (c *conn) serve() { for { w, err := c.readRequest() + if c.lr.N != c.server.initialLimitedReaderSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, StateActive) + } if err != nil { if err == errTooLarge { // Their HTTP client may or may not be @@ -1143,16 +1156,10 @@ func (c *conn) serve() { // Expect 100 Continue support req := w.req if req.expectsContinue() { - if req.ProtoAtLeast(1, 1) { + if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} } - if req.ContentLength == 0 { - w.Header().Set("Connection", "close") - w.WriteHeader(StatusBadRequest) - w.finishRequest() - break - } req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() @@ -1175,6 +1182,7 @@ func (c *conn) serve() { } break } + c.setState(c.rwc, StateIdle) } } @@ -1202,7 +1210,14 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.wroteHeader { w.cw.flush() } - return w.conn.hijack() + // Release the bufioWriter that writes to the chunk writer, it is not + // used after a connection has been hijacked. + rwc, buf, err = w.conn.hijack() + if err == nil { + putBufioWriter(w.w) + w.w = nil + } + return rwc, buf, err } func (w *response) CloseNotify() <-chan bool { @@ -1562,6 +1577,7 @@ func Serve(l net.Listener, handler Handler) error { } // A Server defines parameters for running an HTTP server. +// The zero value for Server is a valid configuration. type Server struct { Addr string // TCP address to listen on, ":http" if empty Handler Handler // handler to invoke, http.DefaultServeMux if nil @@ -1578,6 +1594,66 @@ type Server struct { // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + + // ErrorLog specifies an optional logger for errors accepting + // connections and unexpected behavior from handlers. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger + + disableKeepAlives int32 // accessed atomically. +} + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. The Server.ConnState hook for + // StateActive fires before the request has entered a handler + // and doesn't fire again until the request has been + // handled. After the request is handled, the state + // transitions to StateClosed, StateHijacked, or StateIdle. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. + StateIdle + + // StateHijacked represents a hijacked connection. + // This is a terminal state. It does not transition to StateClosed. + StateHijacked + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateHijacked: "hijacked", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] } // serverHandler delegates to either the server's Handler or @@ -1605,11 +1681,11 @@ func (srv *Server) ListenAndServe() error { if addr == "" { addr = ":http" } - l, e := net.Listen("tcp", addr) - if e != nil { - return e + ln, err := net.Listen("tcp", addr) + if err != nil { + return err } - return srv.Serve(l) + return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) } // Serve accepts incoming connections on the Listener l, creating a @@ -1630,7 +1706,7 @@ func (srv *Server) Serve(l net.Listener) error { if max := 1 * time.Second; tempDelay > max { tempDelay = max } - log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay) + srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay) time.Sleep(tempDelay) continue } @@ -1641,10 +1717,35 @@ func (srv *Server) Serve(l net.Listener) error { if err != nil { continue } + c.setState(c.rwc, StateNew) // before Serve can return go c.serve() } } +func (s *Server) doKeepAlives() bool { + return atomic.LoadInt32(&s.disableKeepAlives) == 0 +} + +// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. +// By default, keep-alives are always enabled. Only very +// resource-constrained environments or servers in the process of +// shutting down should disable them. +func (s *Server) SetKeepAlivesEnabled(v bool) { + if v { + atomic.StoreInt32(&s.disableKeepAlives, 0) + } else { + atomic.StoreInt32(&s.disableKeepAlives, 1) + } +} + +func (s *Server) logf(format string, args ...interface{}) { + if s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + // ListenAndServe listens on the TCP network address addr // and then calls Serve with handler to handle requests // on incoming connections. Handler is typically nil, @@ -1739,12 +1840,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { return err } - conn, err := net.Listen("tcp", addr) + ln, err := net.Listen("tcp", addr) if err != nil { return err } - tlsListener := tls.NewListener(conn, config) + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) return srv.Serve(tlsListener) } @@ -1834,6 +1935,24 @@ func (tw *timeoutWriter) WriteHeader(code int) { tw.w.WriteHeader(code) } +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + // globalOptionsHandler responds to "OPTIONS *" requests. type globalOptionsHandler struct{} @@ -1850,17 +1969,24 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } +type eofReaderWithWriteTo struct{} + +func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } +func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } + // eofReader is a non-nil io.ReadCloser that always returns EOF. -// It embeds a *strings.Reader so it still has a WriteTo method -// and io.Copy won't need a buffer. +// It has a WriteTo method so io.Copy won't need a buffer. var eofReader = &struct { - *strings.Reader + eofReaderWithWriteTo io.Closer }{ - strings.NewReader(""), + eofReaderWithWriteTo{}, ioutil.NopCloser(nil), } +// Verify that an io.Copy from an eofReader won't require a buffer. +var _ io.WriterTo = eofReader + // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index bacd83732..7f6368652 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -12,10 +12,20 @@ import ( "io" "io/ioutil" "net/textproto" + "sort" "strconv" "strings" + "sync" ) +type errorReader struct { + err error +} + +func (r *errorReader) Read(p []byte) (n int, err error) { + return 0, r.err +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -52,14 +62,17 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { if t.ContentLength == 0 { // Test to see if it's actually zero or just unset. var buf [1]byte - n, _ := io.ReadFull(t.Body, buf[:]) - if n == 1 { + n, rerr := io.ReadFull(t.Body, buf[:]) + if rerr != nil && rerr != io.EOF { + t.ContentLength = -1 + t.Body = &errorReader{rerr} + } else if n == 1 { // Oh, guess there is data in this Body Reader after all. // The ContentLength field just wasn't set. // Stich the Body back together again, re-attaching our // consumed byte. t.ContentLength = -1 - t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body) + t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body) } else { // Body is actually empty. t.Body = nil @@ -131,11 +144,10 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } -func (t *transferWriter) WriteHeader(w io.Writer) (err error) { +func (t *transferWriter) WriteHeader(w io.Writer) error { if t.Close { - _, err = io.WriteString(w, "Connection: close\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + return err } } @@ -143,43 +155,44 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err error) { // function of the sanitized field triple (Body, ContentLength, // TransferEncoding) if t.shouldSendContentLength() { - io.WriteString(w, "Content-Length: ") - _, err = io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Content-Length: "); err != nil { + return err + } + if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + return err } } else if chunked(t.TransferEncoding) { - _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + return err } } // Write Trailer header if t.Trailer != nil { - // TODO: At some point, there should be a generic mechanism for - // writing long headers, using HTTP line splitting - io.WriteString(w, "Trailer: ") - needComma := false + keys := make([]string, 0, len(t.Trailer)) for k := range t.Trailer { k = CanonicalHeaderKey(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": return &badStringError{"invalid Trailer key", k} } - if needComma { - io.WriteString(w, ",") + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + return err } - io.WriteString(w, k) - needComma = true } - _, err = io.WriteString(w, "\r\n") } - return + return nil } -func (t *transferWriter) WriteBody(w io.Writer) (err error) { +func (t *transferWriter) WriteBody(w io.Writer) error { + var err error var ncopy int64 // Write body @@ -216,11 +229,16 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { // TODO(petar): Place trailer writer code here. if chunked(t.TransferEncoding) { + // Write Trailer header + if t.Trailer != nil { + if err := t.Trailer.Write(w); err != nil { + return err + } + } // Last chunk, empty trailer _, err = io.WriteString(w, "\r\n") } - - return + return err } type transferReader struct { @@ -252,6 +270,22 @@ func bodyAllowedForStatus(status int) bool { return true } +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +} + // msg is *Request or *Response. func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} @@ -331,17 +365,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { if noBodyExpected(t.RequestMethod) { t.Body = eofReader } else { - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = eofReader case realLength > 0: - t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{Reader: r, closing: t.Close} + t.Body = &body{src: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = eofReader @@ -498,7 +532,7 @@ func fixTrailer(header Header, te []string) (Header, error) { case "Transfer-Encoding", "Trailer", "Content-Length": return nil, &badStringError{"bad trailer key", key} } - trailer.Del(key) + trailer[key] = nil } if len(trailer) == 0 { return nil, nil @@ -514,11 +548,13 @@ func fixTrailer(header Header, te []string) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - io.Reader + src io.Reader hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? - closed bool + + mu sync.Mutex // guards closed, and calls to Read and Close + closed bool } // ErrBodyReadAfterClose is returned when reading a Request or Response @@ -528,10 +564,17 @@ type body struct { var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return 0, ErrBodyReadAfterClose } - n, err = b.Reader.Read(p) + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + n, err = b.src.Read(p) if err == io.EOF { // Chunked case. Read the trailer. @@ -543,12 +586,23 @@ func (b *body) Read(p []byte) (n int, err error) { } else { // If the server declared the Content-Length, our body is a LimitedReader // and we need to check whether this EOF arrived early. - if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { err = io.ErrUnexpectedEOF } } } + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + } + } + return n, err } @@ -610,14 +664,26 @@ func (b *body) readTrailer() error { } switch rr := b.hdr.(type) { case *Request: - rr.Trailer = Header(hdr) + mergeSetHeader(&rr.Trailer, Header(hdr)) case *Response: - rr.Trailer = Header(hdr) + mergeSetHeader(&rr.Trailer, Header(hdr)) } return nil } +func mergeSetHeader(dst *Header, src Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return nil } @@ -629,12 +695,25 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, b) + _, err = io.Copy(ioutil.Discard, bodyLocked{b}) } b.closed = true return err } +// bodyLocked is a io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + // parseContentLength trims whitespace from s and returns -1 if no value // is set, or the value if it's >= 0. func parseContentLength(cl string) (int64, error) { diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go index 8627a374c..48cd540b9 100644 --- a/src/pkg/net/http/transfer_test.go +++ b/src/pkg/net/http/transfer_test.go @@ -6,15 +6,16 @@ package http import ( "bufio" + "io" "strings" "testing" ) func TestBodyReadBadTrailer(t *testing.T) { b := &body{ - Reader: strings.NewReader("foobar"), - hdr: true, // force reading the trailer - r: bufio.NewReader(strings.NewReader("")), + src: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), } buf := make([]byte, 7) n, err := b.Read(buf[:3]) @@ -35,3 +36,29 @@ func TestBodyReadBadTrailer(t *testing.T) { t.Errorf("final Read was successful (%q), expected error from trailer read", got) } } + +func TestFinalChunkedBodyReadEOF(t *testing.T) { + res, err := ReadResponse(bufio.NewReader(strings.NewReader( + "HTTP/1.1 200 OK\r\n"+ + "Transfer-Encoding: chunked\r\n"+ + "\r\n"+ + "0a\r\n"+ + "Body here\n\r\n"+ + "09\r\n"+ + "continued\r\n"+ + "0\r\n"+ + "\r\n")), nil) + if err != nil { + t.Fatal(err) + } + want := "Body here\ncontinued" + buf := make([]byte, len(want)) + n, err := res.Body.Read(buf) + if n != len(want) || err != io.EOF { + t.Logf("body = %#v", res.Body) + t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) + } + if string(buf) != want { + t.Errorf("buf = %q; want %q", buf, want) + } +} diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index f6871afac..b1cc632a7 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -30,7 +30,14 @@ import ( // and caches them for reuse by subsequent calls. It uses HTTP proxies // as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and // $no_proxy) environment variables. -var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} +var DefaultTransport RoundTripper = &Transport{ + Proxy: ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, +} // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. @@ -40,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleMu sync.Mutex - idleConn map[string][]*persistConn - idleConnCh map[string]chan *persistConn - reqMu sync.Mutex - reqConn map[*Request]*persistConn - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + idleMu sync.Mutex + idleConn map[connectMethodKey][]*persistConn + idleConnCh map[connectMethodKey]chan *persistConn + reqMu sync.Mutex + reqCanceler map[*Request]func() + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -63,6 +70,10 @@ type Transport struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + // DisableKeepAlives, if true, prevents re-use of TCP connections // between different HTTP requests. DisableKeepAlives bool @@ -98,8 +109,11 @@ type Transport struct { // An error is returned if the proxy environment is invalid. // A nil URL and nil error are returned if no proxy is defined in the // environment, or a proxy should not be used for the given request. +// +// As a special case, if req.URL.Host is "localhost" (with or without +// a port number), then a nil URL and nil error will be returned. func ProxyFromEnvironment(req *Request) (*url.URL, error) { - proxy := getenvEitherCase("HTTP_PROXY") + proxy := httpProxyEnv.Get() if proxy == "" { return nil, nil } @@ -149,9 +163,11 @@ func (tr *transportRequest) extraHeaders() Header { // and redirects), see Get, Post, and the Client type. func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { if req.URL == nil { + req.closeBody() return nil, errors.New("http: nil Request.URL") } if req.Header == nil { + req.closeBody() return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { @@ -162,16 +178,19 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { } t.altMu.RUnlock() if rt == nil { + req.closeBody() return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } return rt.RoundTrip(req) } if req.URL.Host == "" { + req.closeBody() return nil, errors.New("http: no Host in request URL") } treq := &transportRequest{Request: req} cm, err := t.connectMethodForRequest(treq) if err != nil { + req.closeBody() return nil, err } @@ -179,8 +198,10 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { // host (for http or https), the http proxy, or the http proxy // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. - pconn, err := t.getConn(cm) + pconn, err := t.getConn(req, cm) if err != nil { + t.setReqCanceler(req, nil) + req.closeBody() return nil, err } @@ -218,9 +239,6 @@ func (t *Transport) CloseIdleConnections() { t.idleConn = nil t.idleConnCh = nil t.idleMu.Unlock() - if m == nil { - return - } for _, conns := range m { for _, pconn := range conns { pconn.close() @@ -232,10 +250,10 @@ func (t *Transport) CloseIdleConnections() { // connection. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() - pc := t.reqConn[req] + cancel := t.reqCanceler[req] t.reqMu.Unlock() - if pc != nil { - pc.conn.Close() + if cancel != nil { + cancel() } } @@ -243,24 +261,49 @@ func (t *Transport) CancelRequest(req *Request) { // Private implementation past this point. // -func getenvEitherCase(k string) string { - if v := os.Getenv(strings.ToUpper(k)); v != "" { - return v +var ( + httpProxyEnv = &envOnce{ + names: []string{"HTTP_PROXY", "http_proxy"}, } - return os.Getenv(strings.ToLower(k)) + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, error) { - cm := &connectMethod{ - targetScheme: treq.URL.Scheme, - targetAddr: canonicalAddr(treq.URL), +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { - var err error cm.proxyURL, err = t.Proxy(treq.Request) - if err != nil { - return nil, err - } } return cm, nil } @@ -316,7 +359,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } } if t.idleConn == nil { - t.idleConn = make(map[string][]*persistConn) + t.idleConn = make(map[connectMethodKey][]*persistConn) } if len(t.idleConn[key]) >= max { t.idleMu.Unlock() @@ -336,7 +379,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { // getIdleConnCh returns a channel to receive and return idle // persistent connection for the given connectMethod. // It may return nil, if persistent connections are not being used. -func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { +func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn { if t.DisableKeepAlives { return nil } @@ -344,7 +387,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { t.idleMu.Lock() defer t.idleMu.Unlock() if t.idleConnCh == nil { - t.idleConnCh = make(map[string]chan *persistConn) + t.idleConnCh = make(map[connectMethodKey]chan *persistConn) } ch, ok := t.idleConnCh[key] if !ok { @@ -354,7 +397,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { return ch } -func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { +func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) { key := cm.key() t.idleMu.Lock() defer t.idleMu.Unlock() @@ -373,7 +416,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { // 2 or more cached connections; pop last // TODO: queue? pconn = pconns[len(pconns)-1] - t.idleConn[key] = pconns[0 : len(pconns)-1] + t.idleConn[key] = pconns[:len(pconns)-1] } if !pconn.isBroken() { return @@ -381,16 +424,16 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { } } -func (t *Transport) setReqConn(r *Request, pc *persistConn) { +func (t *Transport) setReqCanceler(r *Request, fn func()) { t.reqMu.Lock() defer t.reqMu.Unlock() - if t.reqConn == nil { - t.reqConn = make(map[*Request]*persistConn) + if t.reqCanceler == nil { + t.reqCanceler = make(map[*Request]func()) } - if pc != nil { - t.reqConn[r] = pc + if fn != nil { + t.reqCanceler[r] = fn } else { - delete(t.reqConn, r) + delete(t.reqCanceler, r) } } @@ -405,7 +448,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) { // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn // is ready to write requests to. -func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { +func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { if pc := t.getIdleConn(cm); pc != nil { return pc, nil } @@ -415,6 +458,16 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { err error } dialc := make(chan dialRes) + + handlePendingDial := func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + } + + cancelc := make(chan struct{}) + t.setReqCanceler(req, func() { close(cancelc) }) + go func() { pc, err := t.dialConn(cm) dialc <- dialRes{pc, err} @@ -431,16 +484,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { // else's dial that they didn't use. // But our dial is still going, so give it away // when it finishes: - go func() { - if v := <-dialc; v.err == nil { - t.putIdleConn(v.pc) - } - }() + go handlePendingDial() return pc, nil + case <-cancelc: + go handlePendingDial() + return nil, errors.New("net/http: request canceled while waiting for connection") } } -func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { +func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { @@ -452,12 +504,13 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { pa := cm.proxyAuth() pconn := &persistConn{ - t: t, - cacheKey: cm.key(), - conn: conn, - reqch: make(chan requestAndChan, 50), - writech: make(chan writeRequest, 50), - closech: make(chan struct{}), + t: t, + cacheKey: cm.key(), + conn: conn, + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), } switch { @@ -511,19 +564,38 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { cfg = &clone } } - conn = tls.Client(conn, cfg) - if err = conn.(*tls.Conn).Handshake(); err != nil { + plainConn := conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() return nil, err } if !cfg.InsecureSkipVerify { - if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() return nil, err } } - pconn.conn = conn + cs := tlsConn.ConnectionState() + pconn.tlsState = &cs + pconn.conn = tlsConn } - pconn.br = bufio.NewReader(pconn.conn) + pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() go pconn.writeLoop() @@ -550,7 +622,7 @@ func useProxy(addr string) bool { } } - no_proxy := getenvEitherCase("NO_PROXY") + no_proxy := noProxyEnv.Get() if no_proxy == "*" { return false } @@ -590,8 +662,8 @@ func useProxy(addr string) bool { // // Cache key form Description // ----------------- ------------------------- -// ||http|foo.com http directly to server, no proxy -// ||https|foo.com https directly to server, no proxy +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy // http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com // http://proxy.com|http http to proxy, http to anywhere after that // @@ -603,20 +675,20 @@ type connectMethod struct { targetAddr string // Not used if proxy + http targetScheme (4th example in table) } -func (ck *connectMethod) key() string { - return ck.String() // TODO: use a struct type instead -} - -func (ck *connectMethod) String() string { +func (cm *connectMethod) key() connectMethodKey { proxyStr := "" - targetAddr := ck.targetAddr - if ck.proxyURL != nil { - proxyStr = ck.proxyURL.String() - if ck.targetScheme == "http" { + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if cm.targetScheme == "http" { targetAddr = "" } } - return strings.Join([]string{proxyStr, ck.targetScheme, targetAddr}, "|") + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + } } // addr returns the first hop "host:port" to which we need to TCP connect. @@ -637,22 +709,41 @@ func (cm *connectMethod) tlsHost() string { return h } +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string +} + +func (k connectMethodKey) String() string { + // Only used by tests. + return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr) +} + // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { t *Transport - cacheKey string // its connectMethod.String() + cacheKey connectMethodKey conn net.Conn - closed bool // whether conn has been closed + tlsState *tls.ConnectionState br *bufio.Reader // from conn + sawEOF bool // whether we've seen EOF from conn; owned by readLoop bw *bufio.Writer // to conn reqch chan requestAndChan // written by roundTrip; read by readLoop writech chan writeRequest // written by roundTrip; read by writeLoop - closech chan struct{} // broadcast close when readLoop (TCP connection) closes + closech chan struct{} // closed when conn closed isProxy bool + // writeErrCh passes the request write error (usually nil) + // from the writeLoop goroutine to the readLoop which passes + // it off to the res.Body reader, which then uses it to decide + // whether or not a connection can be reused. Issue 7569. + writeErrCh chan error - lk sync.Mutex // guards following 3 fields + lk sync.Mutex // guards following fields numExpectedResponses int + closed bool // whether conn has been closed broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the @@ -660,6 +751,7 @@ type persistConn struct { mutateHeaderFunc func(Header) } +// isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { pc.lk.Lock() b := pc.broken @@ -667,6 +759,10 @@ func (pc *persistConn) isBroken() bool { return b } +func (pc *persistConn) cancelRequest() { + pc.conn.Close() +} + var remoteSideClosedFunc func(error) bool // or nil to use default func remoteSideClosed(err error) bool { @@ -680,7 +776,6 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { - defer close(pc.closech) alive := true for alive { @@ -688,12 +783,14 @@ func (pc *persistConn) readLoop() { pc.lk.Lock() if pc.numExpectedResponses == 0 { - pc.closeLocked() - pc.lk.Unlock() - if len(pb) > 0 { - log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", - string(pb), err) + if !pc.closed { + pc.closeLocked() + if len(pb) > 0 { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", + string(pb), err) + } } + pc.lk.Unlock() return } pc.lk.Unlock() @@ -712,6 +809,11 @@ func (pc *persistConn) readLoop() { resp, err = ReadResponse(pc.br, rc.req) } } + + if resp != nil { + resp.TLS = pc.tlsState + } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { @@ -721,13 +823,7 @@ func (pc *persistConn) readLoop() { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 - gzReader, zerr := gzip.NewReader(resp.Body) - if zerr != nil { - pc.close() - err = zerr - } else { - resp.Body = &readerAndCloser{gzReader, resp.Body} - } + resp.Body = &gzipReader{body: resp.Body} } resp.Body = &bodyEOFSignal{body: resp.Body} } @@ -750,24 +846,18 @@ func (pc *persistConn) readLoop() { return nil } resp.Body.(*bodyEOFSignal).fn = func(err error) { - alive1 := alive - if err != nil { - alive1 = false - } - if alive1 && !pc.t.putIdleConn(pc) { - alive1 = false - } - if !alive1 || pc.isBroken() { - pc.close() - } - waitForBodyRead <- alive1 + waitForBodyRead <- alive && + err == nil && + !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) } } if alive && !hasBody { - if !pc.t.putIdleConn(pc) { - alive = false - } + alive = !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) } rc.ch <- responseAndError{resp, err} @@ -775,10 +865,14 @@ func (pc *persistConn) readLoop() { // Wait for the just-returned response body to be fully consumed // before we race and peek on the underlying bufio reader. if waitForBodyRead != nil { - alive = <-waitForBodyRead + select { + case alive = <-waitForBodyRead: + case <-pc.closech: + alive = false + } } - pc.t.setReqConn(rc.req, nil) + pc.t.setReqCanceler(rc.req, nil) if !alive { pc.close() @@ -800,14 +894,44 @@ func (pc *persistConn) writeLoop() { } if err != nil { pc.markBroken() + wr.req.Request.closeBody() } - wr.ch <- err + pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function case <-pc.closech: return } } } +// wroteRequest is a check before recycling a connection that the previous write +// (from writeLoop above) happened and was successful. +func (pc *persistConn) wroteRequest() bool { + select { + case err := <-pc.writeErrCh: + // Common case: the write happened well before the response, so + // avoid creating a timer. + return err == nil + default: + // Rare case: the request was written in writeLoop above but + // before it could send to pc.writeErrCh, the reader read it + // all, processed it, and called us here. In this case, give the + // write goroutine a bit of time to finish its send. + // + // Less rare case: We also get here in the legitimate case of + // Issue 7569, where the writer is still writing (or stalled), + // but the server has already replied. In this case, we don't + // want to wait too long, and we want to return false so this + // connection isn't re-used. + select { + case err := <-pc.writeErrCh: + return err == nil + case <-time.After(50 * time.Millisecond): + return false + } + } +} + type responseAndError struct { res *Response err error @@ -832,8 +956,20 @@ type writeRequest struct { ch chan<- error } +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} +var errClosed error = &httpError{err: "net/http: transport closed before response was received"} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - pc.t.setReqConn(req.Request, pc) + pc.t.setReqCanceler(req.Request, pc.cancelRequest) pc.lk.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -902,11 +1038,11 @@ WaitResponse: pconnDeadCh = nil // avoid spinning failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc case <-failTicker: - re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + re = responseAndError{err: errClosed} break WaitResponse case <-respHeaderTimer: pc.close() - re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} + re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: break WaitResponse @@ -918,7 +1054,7 @@ WaitResponse: pc.lk.Unlock() if re.err != nil { - pc.t.setReqConn(req.Request, nil) + pc.t.setReqCanceler(req.Request, nil) } return re.res, re.err } @@ -943,6 +1079,7 @@ func (pc *persistConn) closeLocked() { if !pc.closed { pc.conn.Close() pc.closed = true + close(pc.closech) } pc.mutateHeaderFunc = nil } @@ -1025,7 +1162,47 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr io.Reader // lazily-initialized gzip reader +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + type readerAndCloser struct { io.Reader io.Closer } + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +type noteEOFReader struct { + r io.Reader + sawEOF *bool +} + +func (nr noteEOFReader) Read(p []byte) (n int, err error) { + n, err = nr.r.Read(p) + if err == io.EOF { + *nr.sawEOF = true + } + return +} diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index e4df30a98..964ca0fca 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -11,9 +11,12 @@ import ( "bytes" "compress/gzip" "crypto/rand" + "crypto/tls" + "errors" "fmt" "io" "io/ioutil" + "log" "net" "net/http" . "net/http" @@ -54,21 +57,21 @@ func (c *testCloseConn) Close() error { // been closed. type testConnSet struct { t *testing.T + mu sync.Mutex // guards closed and list closed map[net.Conn]bool list []net.Conn // in order created - mutex sync.Mutex } func (tcs *testConnSet) insert(c net.Conn) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() + tcs.mu.Lock() + defer tcs.mu.Unlock() tcs.closed[c] = false tcs.list = append(tcs.list, c) } func (tcs *testConnSet) remove(c net.Conn) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() + tcs.mu.Lock() + defer tcs.mu.Unlock() tcs.closed[c] = true } @@ -91,11 +94,19 @@ func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, e } func (tcs *testConnSet) check(t *testing.T) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() - - for i, c := range tcs.list { - if !tcs.closed[c] { + tcs.mu.Lock() + defer tcs.mu.Unlock() + for i := 4; i >= 0; i-- { + for i, c := range tcs.list { + if tcs.closed[c] { + continue + } + if i != 0 { + tcs.mu.Unlock() + time.Sleep(50 * time.Millisecond) + tcs.mu.Lock() + continue + } t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) } } @@ -271,6 +282,58 @@ func TestTransportIdleCacheKeys(t *testing.T) { } } +// Tests that the HTTP transport re-uses connections when a client +// reads to the end of a response Body without closing it. +func TestTransportReadToEndReusesConn(t *testing.T) { + defer afterTest(t) + const msg = "foobar" + + var addrSeen map[string]int + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addrSeen[r.RemoteAddr]++ + if r.URL.Path == "/chunked/" { + w.WriteHeader(200) + w.(http.Flusher).Flush() + } else { + w.Header().Set("Content-Type", strconv.Itoa(len(msg))) + w.WriteHeader(200) + } + w.Write([]byte(msg)) + })) + defer ts.Close() + + buf := make([]byte, len(msg)) + + for pi, path := range []string{"/content-length/", "/chunked/"} { + wantLen := []int{len(msg), -1}[pi] + addrSeen = make(map[string]int) + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL + path) + if err != nil { + t.Errorf("Get %s: %v", path, err) + continue + } + // We want to close this body eventually (before the + // defer afterTest at top runs), but not before the + // len(addrSeen) check at the bottom of this test, + // since Closing this early in the loop would risk + // making connections be re-used for the wrong reason. + defer res.Body.Close() + + if res.ContentLength != int64(wantLen) { + t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) + } + n, err := res.Body.Read(buf) + if n != len(msg) || err != io.EOF { + t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + } + } + if len(addrSeen) != 1 { + t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) + } + } +} + func TestTransportMaxPerHostIdleConns(t *testing.T) { defer afterTest(t) resch := make(chan string) @@ -295,10 +358,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { resp, err := c.Get(ts.URL) if err != nil { t.Error(err) + return } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadAll: %v", err) + if _, err := ioutil.ReadAll(resp.Body); err != nil { + t.Errorf("ReadAll: %v", err) + return } donech <- true } @@ -739,8 +803,38 @@ func TestTransportGzipRecursive(t *testing.T) { } } +// golang.org/issue/7750: request fails when server replies with +// a short gzip body +func TestTransportGzipShort(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write([]byte{0x1f, 0x8b}) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err == nil { + t.Fatal("Expect an error from reading a body.") + } + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) + } +} + // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) @@ -798,8 +892,8 @@ func TestTransportPersistConnLeak(t *testing.T) { // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) if int(growth) > 5 { + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) t.Error("too many new goroutines") } } @@ -807,6 +901,9 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) @@ -1014,6 +1111,9 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) const debug = false mux := NewServeMux() @@ -1075,6 +1175,9 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) const debug = false mux := NewServeMux() @@ -1147,9 +1250,13 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } + inHandler := make(chan bool, 1) mux := NewServeMux() - mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { + inHandler <- true + }) mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + inHandler <- true time.Sleep(2 * time.Second) }) ts := httptest.NewServer(mux) @@ -1172,7 +1279,27 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } for i, tt := range tests { res, err := c.Get(ts.URL + tt.path) + select { + case <-inHandler: + case <-time.After(5 * time.Second): + t.Errorf("never entered handler for test index %d, %s", i, tt.path) + continue + } if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not an url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } if strings.Contains(err.Error(), tt.wantErr) { continue } @@ -1243,6 +1370,60 @@ func TestTransportCancelRequest(t *testing.T) { } } +func TestTransportCancelRequestInDial(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + var logbuf bytes.Buffer + eventLog := log.New(&logbuf, "", 0) + + unblockDial := make(chan bool) + defer close(unblockDial) + + inDial := make(chan bool) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + eventLog.Println("dial: blocking") + inDial <- true + <-unblockDial + return nil, errors.New("nope") + }, + } + cl := &Client{Transport: tr} + gotres := make(chan bool) + req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + go func() { + _, err := cl.Do(req) + eventLog.Printf("Get = %v", err) + gotres <- true + }() + + select { + case <-inDial: + case <-time.After(5 * time.Second): + t.Fatal("timeout; never saw blocking dial") + } + + eventLog.Printf("canceling") + tr.CancelRequest(req) + + select { + case <-gotres: + case <-time.After(5 * time.Second): + panic("hang. events are: " + logbuf.String()) + } + + got := logbuf.String() + want := `dial: blocking +canceling +Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection +` + if got != want { + t.Errorf("Got events:\n%s\nWant:\n%s", got, want) + } +} + // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. @@ -1283,7 +1464,7 @@ func TestTransportCloseResponseBody(t *testing.T) { t.Fatal(err) } if !bytes.Equal(buf, want) { - t.Errorf("read %q; want %q", buf, want) + t.Fatalf("read %q; want %q", buf, want) } didClose := make(chan error, 1) go func() { @@ -1372,8 +1553,10 @@ func TestTransportSocketLateBinding(t *testing.T) { dialGate := make(chan bool, 1) tr := &Transport{ Dial: func(n, addr string) (net.Conn, error) { - <-dialGate - return net.Dial(n, addr) + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") }, DisableKeepAlives: false, } @@ -1408,7 +1591,7 @@ func TestTransportSocketLateBinding(t *testing.T) { t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) } barRes.Body.Close() - dialGate <- true + dialGate <- false } // Issue 2184 @@ -1559,13 +1742,11 @@ var proxyFromEnvTests = []proxyFromEnvTest{ } func TestProxyFromEnvironment(t *testing.T) { - os.Setenv("HTTP_PROXY", "") - os.Setenv("http_proxy", "") - os.Setenv("NO_PROXY", "") - os.Setenv("no_proxy", "") + ResetProxyEnv() for _, tt := range proxyFromEnvTests { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("NO_PROXY", tt.noenv) + ResetCachedEnvironment() reqURL := tt.req if reqURL == "" { reqURL = "http://example.com" @@ -1643,6 +1824,308 @@ func TestTransportClosesRequestBody(t *testing.T) { } } +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +// Trying to repro golang.org/issue/3514 +func TestTLSServerClosesConnection(t *testing.T) { + defer afterTest(t) + if runtime.GOOS == "windows" { + t.Skip("skipping flaky test on Windows; golang.org/issue/7634") + } + closedc := make(chan bool, 1) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/keep-alive-then-die") { + conn, _, _ := w.(Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) + conn.Close() + closedc <- true + return + } + fmt.Fprintf(w, "hello") + })) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + var nSuccess = 0 + var errs []error + const trials = 20 + for i := 0; i < trials; i++ { + tr.CloseIdleConnections() + res, err := client.Get(ts.URL + "/keep-alive-then-die") + if err != nil { + t.Fatal(err) + } + <-closedc + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "foo" { + t.Errorf("Got %q, want foo", slurp) + } + + // Now try again and see if we successfully + // pick a new connection. + res, err = client.Get(ts.URL + "/") + if err != nil { + errs = append(errs, err) + continue + } + slurp, err = ioutil.ReadAll(res.Body) + if err != nil { + errs = append(errs, err) + continue + } + nSuccess++ + } + if nSuccess > 0 { + t.Logf("successes = %d of %d", nSuccess, trials) + } else { + t.Errorf("All runs failed:") + } + for _, err := range errs { + t.Logf(" err: %v", err) + } +} + +// byteFromChanReader is an io.Reader that reads a single byte at a +// time from the channel. When the channel is closed, the reader +// returns io.EOF. +type byteFromChanReader chan byte + +func (c byteFromChanReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + b, ok := <-c + if !ok { + return 0, io.EOF + } + p[0] = b + return 1, nil +} + +// Verifies that the Transport doesn't reuse a connection in the case +// where the server replies before the request has been fully +// written. We still honor that reply (see TestIssue3595), but don't +// send future requests on the connection because it's then in a +// questionable state. +// golang.org/issue/7569 +func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + defer afterTest(t) + var sconn struct { + sync.Mutex + c net.Conn + } + var getOkay bool + closeConn := func() { + sconn.Lock() + defer sconn.Unlock() + if sconn.c != nil { + sconn.c.Close() + sconn.c = nil + if !getOkay { + t.Logf("Closed server connection") + } + } + } + defer closeConn() + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method == "GET" { + io.WriteString(w, "bar") + return + } + conn, _, _ := w.(Hijacker).Hijack() + sconn.Lock() + sconn.c = conn + sconn.Unlock() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive + go io.Copy(ioutil.Discard, conn) + })) + defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + const bodySize = 256 << 10 + finalBit := make(byteFromChanReader, 1) + req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req.ContentLength = bodySize + res, err := client.Do(req) + if err := wantBody(res, err, "foo"); err != nil { + t.Errorf("POST response: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + res, err = client.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return + } + getOkay = true // suppress test noise + }() + time.AfterFunc(5*time.Second, closeConn) + select { + case <-donec: + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) + case <-time.After(7 * time.Second): + t.Fatal("timeout waiting for GET request to finish") + } +} + +type errorReader struct { + err error +} + +func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +// Issue 6981 +func TestTransportClosesBodyOnError(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7782") + } + defer afterTest(t) + readBody := make(chan error, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := ioutil.ReadAll(r.Body) + readBody <- err + })) + defer ts.Close() + fakeErr := errors.New("fake error") + didClose := make(chan bool, 1) + req, _ := NewRequest("POST", ts.URL, struct { + io.Reader + io.Closer + }{ + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), + closerFunc(func() error { + select { + case didClose <- true: + default: + } + return nil + }), + }) + res, err := DefaultClient.Do(req) + if res != nil { + defer res.Body.Close() + } + if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { + t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) + } + select { + case err := <-readBody: + if err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") + } + case <-time.After(5 * time.Second): + t.Error("timeout waiting for server handler to complete") + } + select { + case <-didClose: + default: + t.Errorf("didn't see Body.Close") + } +} + +func wantBody(res *http.Response, err error, want string) error { + if err != nil { + return err + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading body: %v", err) + } + if string(slurp) != want { + return fmt.Errorf("body = %q; want %q", slurp, want) + } + if err := res.Body.Close(); err != nil { + return fmt.Errorf("body Close = %v", err) + } + return nil +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + type countCloseReader struct { n *int io.Reader diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go index 0713e9cd6..2e9f1ebc6 100644 --- a/src/pkg/net/interface.go +++ b/src/pkg/net/interface.go @@ -7,11 +7,11 @@ package net import "errors" var ( - errInvalidInterface = errors.New("net: invalid interface") - errInvalidInterfaceIndex = errors.New("net: invalid interface index") - errInvalidInterfaceName = errors.New("net: invalid interface name") - errNoSuchInterface = errors.New("net: no such interface") - errNoSuchMulticastInterface = errors.New("net: no such multicast interface") + errInvalidInterface = errors.New("invalid network interface") + errInvalidInterfaceIndex = errors.New("invalid network interface index") + errInvalidInterfaceName = errors.New("invalid network interface name") + errNoSuchInterface = errors.New("no such network interface") + errNoSuchMulticastInterface = errors.New("no such multicast network interface") ) // Interface represents a mapping between network interface name diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go index 1207c0f26..1115d0fc4 100644 --- a/src/pkg/net/interface_linux.go +++ b/src/pkg/net/interface_linux.go @@ -45,15 +45,41 @@ loop: return ift, nil } +const ( + // See linux/if_arp.h. + // Note that Linux doesn't support IPv4 over IPv6 tunneling. + sysARPHardwareIPv4IPv4 = 768 // IPv4 over IPv4 tunneling + sysARPHardwareIPv6IPv6 = 769 // IPv6 over IPv6 tunneling + sysARPHardwareIPv6IPv4 = 776 // IPv6 over IPv4 tunneling + sysARPHardwareGREIPv4 = 778 // any over GRE over IPv4 tunneling + sysARPHardwareGREIPv6 = 823 // any over GRE over IPv6 tunneling +) + func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface { ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} for _, a := range attrs { switch a.Attr.Type { case syscall.IFLA_ADDRESS: + // We never return any /32 or /128 IP address + // prefix on any IP tunnel interface as the + // hardware address. + switch len(a.Value) { + case IPv4len: + switch ifim.Type { + case sysARPHardwareIPv4IPv4, sysARPHardwareGREIPv4, sysARPHardwareIPv6IPv4: + continue + } + case IPv6len: + switch ifim.Type { + case sysARPHardwareIPv6IPv6, sysARPHardwareGREIPv6: + continue + } + } var nonzero bool for _, b := range a.Value { if b != 0 { nonzero = true + break } } if nonzero { @@ -147,19 +173,31 @@ loop: } func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr { - for _, a := range attrs { - if ifi.Flags&FlagPointToPoint != 0 && a.Attr.Type == syscall.IFA_LOCAL || - ifi.Flags&FlagPointToPoint == 0 && a.Attr.Type == syscall.IFA_ADDRESS { - switch ifam.Family { - case syscall.AF_INET: - return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)} - case syscall.AF_INET6: - ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)} - copy(ifa.IP, a.Value[:]) - return ifa + var ipPointToPoint bool + // Seems like we need to make sure whether the IP interface + // stack consists of IP point-to-point numbered or unnumbered + // addressing over point-to-point link encapsulation. + if ifi.Flags&FlagPointToPoint != 0 { + for _, a := range attrs { + if a.Attr.Type == syscall.IFA_LOCAL { + ipPointToPoint = true + break } } } + for _, a := range attrs { + if ipPointToPoint && a.Attr.Type == syscall.IFA_ADDRESS || !ipPointToPoint && a.Attr.Type == syscall.IFA_LOCAL { + continue + } + switch ifam.Family { + case syscall.AF_INET: + return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)} + case syscall.AF_INET6: + ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)} + copy(ifa.IP, a.Value[:]) + return ifa + } + } return nil } diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go index a4eb731da..c38fb7f76 100644 --- a/src/pkg/net/interface_stub.go +++ b/src/pkg/net/interface_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build plan9 +// +build nacl plan9 solaris package net diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index fd6a7d4ee..0582009b8 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -623,6 +623,9 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) { for k := ellipsis + n - 1; k >= ellipsis; k-- { ip[k] = 0 } + } else if ellipsis >= 0 { + // Ellipsis must represent at least one 0 group. + return nil, zone } return ip, zone } diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index 26b53729b..ffeb9d315 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -25,6 +25,7 @@ var parseIPTests = []struct { {"fe80::1%lo0", nil}, {"fe80::1%911", nil}, {"", nil}, + {"a1:a2:a3:a4::b1:b2:b3:b4", nil}, // Issue 6628 } func TestParseIP(t *testing.T) { diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index ea183f1d3..0632dafc6 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -247,7 +247,7 @@ var ipConnLocalNameTests = []struct { func TestIPConnLocalName(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": t.Skipf("skipping test on %q", runtime.GOOS) default: if os.Getuid() != 0 { @@ -277,7 +277,7 @@ func TestIPConnRemoteName(t *testing.T) { } } - raddr := &IPAddr{IP: IPv4(127, 0, 0, 10).To4()} + raddr := &IPAddr{IP: IPv4(127, 0, 0, 1).To4()} c, err := DialIP("ip:tcp", &IPAddr{IP: IPv4(127, 0, 0, 1)}, raddr) if err != nil { t.Fatalf("DialIP failed: %v", err) diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go index 722853257..bbb3f3ed6 100644 --- a/src/pkg/net/iprawsock_posix.go +++ b/src/pkg/net/iprawsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net @@ -19,7 +19,7 @@ import ( // that you do not uses these methods if it is important to receive a // full packet. // -// The Go 1 compatibliity guidelines make it impossible for us to +// The Go 1 compatibility guidelines make it impossible for us to // change the behavior of these methods; use Read or ReadMsgIP // instead. @@ -79,7 +79,7 @@ func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { // TODO(cw,rsc): consider using readv if we know the family // type to avoid the header trim/copy var addr *IPAddr - n, sa, err := c.fd.ReadFrom(b) + n, sa, err := c.fd.readFrom(b) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &IPAddr{IP: sa.Addr[0:]} @@ -112,7 +112,7 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err return 0, 0, 0, nil, syscall.EINVAL } var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob) + n, oobn, flags, sa, err = c.fd.readMsg(b, oob) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &IPAddr{IP: sa.Addr[0:]} @@ -133,6 +133,9 @@ func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL } + if c.fd.isConnected { + return 0, &OpError{Op: "write", Net: c.fd.net, Addr: addr, Err: ErrWriteToConnected} + } if addr == nil { return 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} } @@ -140,7 +143,7 @@ func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if err != nil { return 0, &OpError{"write", c.fd.net, addr, err} } - return c.fd.WriteTo(b, sa) + return c.fd.writeTo(b, sa) } // WriteTo implements the PacketConn WriteTo method. @@ -162,6 +165,9 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error if !c.ok() { return 0, 0, syscall.EINVAL } + if c.fd.isConnected { + return 0, 0, &OpError{Op: "write", Net: c.fd.net, Addr: addr, Err: ErrWriteToConnected} + } if addr == nil { return 0, 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} } @@ -169,7 +175,7 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error if err != nil { return 0, 0, &OpError{"write", c.fd.net, addr, err} } - return c.fd.WriteMsg(b, oob, sa) + return c.fd.writeMsg(b, oob, sa) } // DialIP connects to the remote address raddr on the network protocol diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 8b586ef7c..dda857803 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -16,7 +16,7 @@ var ( // networking functionality. supportsIPv4 bool - // supportsIPv6 reports whether the platfrom supports IPv6 + // supportsIPv6 reports whether the platform supports IPv6 // networking functionality. supportsIPv6 bool @@ -207,7 +207,7 @@ missingBrackets: } func splitHostZone(s string) (host, zone string) { - // The IPv6 scoped addressing zone identifer starts after the + // The IPv6 scoped addressing zone identifier starts after the // last percent sign. if i := last(s, '%'); i > 0 { host, zone = s[:i], s[i+1:] @@ -232,7 +232,7 @@ func JoinHostPort(host, port string) string { // address or a DNS name and returns an internet protocol family // address. It returns a list that contains a pair of different // address family addresses when addr is a DNS name and the name has -// mutiple address family records. The result contains at least one +// multiple address family records. The result contains at least one // address when error is nil. func resolveInternetAddr(net, addr string, deadline time.Time) (netaddr, error) { var ( diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go index fcec4164f..94ceea31b 100644 --- a/src/pkg/net/ipsock_plan9.go +++ b/src/pkg/net/ipsock_plan9.go @@ -12,19 +12,45 @@ import ( "syscall" ) +func probe(filename, query string) bool { + var file *file + var err error + if file, err = open(filename); err != nil { + return false + } + + r := false + for line, ok := file.readLine(); ok && !r; line, ok = file.readLine() { + f := getFields(line) + if len(f) < 3 { + continue + } + for i := 0; i < len(f); i++ { + if query == f[i] { + r = true + break + } + } + } + file.close() + return r +} + func probeIPv4Stack() bool { - // TODO(mikio): implement this when Plan 9 supports IPv6-only - // kernel. - return true + return probe(netdir+"/iproute", "4i") } // probeIPv6Stack returns two boolean values. If the first boolean // value is true, kernel supports basic IPv6 functionality. If the // second boolean value is true, kernel supports IPv6 IPv4-mapping. func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { - // TODO(mikio): implement this once Plan 9 gets an IPv6 - // protocol stack implementation. - return false, false + // Plan 9 uses IPv6 natively, see ip(3). + r := probe(netdir+"/iproute", "6i") + v := false + if r { + v = probe(netdir+"/iproute", "4i") + } + return r, v } // parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80). @@ -34,12 +60,12 @@ func parsePlan9Addr(s string) (ip IP, iport int, err error) { if i >= 0 { addr = ParseIP(s[:i]) if addr == nil { - return nil, 0, errors.New("net: parsing IP failed") + return nil, 0, errors.New("parsing IP failed") } } p, _, ok := dtoi(s[i+1:], 0) if !ok { - return nil, 0, errors.New("net: parsing port failed") + return nil, 0, errors.New("parsing port failed") } if p < 0 || p > 0xFFFF { return nil, 0, &AddrError{"invalid port", string(p)} @@ -133,18 +159,18 @@ func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { f.Close() return nil, &OpError{"dial", f.Name(), raddr, err} } - data, err := os.OpenFile("/net/"+proto+"/"+name+"/data", os.O_RDWR, 0) + data, err := os.OpenFile(netdir+"/"+proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { f.Close() return nil, &OpError{"dial", net, raddr, err} } - laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + laddr, err = readPlan9Addr(proto, netdir+"/"+proto+"/"+name+"/local") if err != nil { data.Close() f.Close() return nil, &OpError{"dial", proto, raddr, err} } - return newFD(proto, name, f, data, laddr, raddr), nil + return newFD(proto, name, f, data, laddr, raddr) } func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { @@ -158,20 +184,24 @@ func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { f.Close() return nil, &OpError{"announce", proto, laddr, err} } - laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + laddr, err = readPlan9Addr(proto, netdir+"/"+proto+"/"+name+"/local") if err != nil { f.Close() return nil, &OpError{Op: "listen", Net: net, Err: err} } - return newFD(proto, name, f, nil, laddr, nil), nil + return newFD(proto, name, f, nil, laddr, nil) } -func (l *netFD) netFD() *netFD { - return newFD(l.proto, l.name, l.ctl, l.data, l.laddr, l.raddr) +func (l *netFD) netFD() (*netFD, error) { + return newFD(l.proto, l.n, l.ctl, l.data, l.laddr, l.raddr) } func (l *netFD) acceptPlan9() (fd *netFD, err error) { defer func() { netErr(err) }() + if err := l.readLock(); err != nil { + return nil, err + } + defer l.readUnlock() f, err := os.Open(l.dir + "/listen") if err != nil { return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} @@ -183,16 +213,16 @@ func (l *netFD) acceptPlan9() (fd *netFD, err error) { return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} } name := string(buf[:n]) - data, err := os.OpenFile("/net/"+l.proto+"/"+name+"/data", os.O_RDWR, 0) + data, err := os.OpenFile(netdir+"/"+l.proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { f.Close() return nil, &OpError{"accept", l.proto, l.laddr, err} } - raddr, err := readPlan9Addr(l.proto, "/net/"+l.proto+"/"+name+"/remote") + raddr, err := readPlan9Addr(l.proto, netdir+"/"+l.proto+"/"+name+"/remote") if err != nil { data.Close() f.Close() return nil, &OpError{"accept", l.proto, l.laddr, err} } - return newFD(l.proto, name, f, data, l.laddr, raddr), nil + return newFD(l.proto, name, f, data, l.laddr, raddr) } diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go index a83e52561..2ba4c8efd 100644 --- a/src/pkg/net/ipsock_posix.go +++ b/src/pkg/net/ipsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows // Internet protocol family sockets for POSIX @@ -40,12 +40,13 @@ func probeIPv4Stack() bool { func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { var probes = []struct { laddr TCPAddr + value int ok bool }{ // IPv6 communication capability - {TCPAddr{IP: ParseIP("::1")}, false}, + {laddr: TCPAddr{IP: ParseIP("::1")}, value: 1}, // IPv6 IPv4-mapped address communication capability - {TCPAddr{IP: IPv4(127, 0, 0, 1)}, false}, + {laddr: TCPAddr{IP: IPv4(127, 0, 0, 1)}, value: 0}, } for i := range probes { @@ -54,7 +55,7 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { continue } defer closesocket(s) - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value) sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6) if err != nil { continue diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go index f1204a99f..b80ac10e0 100644 --- a/src/pkg/net/lookup_plan9.go +++ b/src/pkg/net/lookup_plan9.go @@ -16,6 +16,10 @@ func query(filename, query string, bufSize int) (res []string, err error) { } defer file.Close() + _, err = file.Seek(0, 0) + if err != nil { + return + } _, err = file.WriteString(query) if err != nil { return @@ -45,7 +49,7 @@ func queryCS(net, host, service string) (res []string, err error) { if host == "" { host = "*" } - return query("/net/cs", net+"!"+host+"!"+service, 128) + return query(netdir+"/cs", net+"!"+host+"!"+service, 128) } func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { @@ -59,20 +63,41 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { } f := getFields(lines[0]) if len(f) < 2 { - return "", "", errors.New("net: bad response from ndb/cs") + return "", "", errors.New("bad response from ndb/cs") } clone, dest = f[0], f[1] return } func queryDNS(addr string, typ string) (res []string, err error) { - return query("/net/dns", addr+" "+typ, 1024) + return query(netdir+"/dns", addr+" "+typ, 1024) +} + +// toLower returns a lower-case version of in. Restricting us to +// ASCII is sufficient to handle the IP protocol names and allow +// us to not depend on the strings and unicode packages. +func toLower(in string) string { + for _, c := range in { + if 'A' <= c && c <= 'Z' { + // Has upper case; need to fix. + out := []byte(in) + for i := 0; i < len(in); i++ { + c := in[i] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + out[i] = c + } + return string(out) + } + } + return in } // lookupProtocol looks up IP protocol name and returns // the corresponding protocol number. func lookupProtocol(name string) (proto int, err error) { - lines, err := query("/net/cs", "!protocol="+name, 128) + lines, err := query(netdir+"/cs", "!protocol="+toLower(name), 128) if err != nil { return 0, err } @@ -92,12 +117,13 @@ func lookupProtocol(name string) (proto int, err error) { } func lookupHost(host string) (addrs []string, err error) { - // Use /net/cs instead of /net/dns because cs knows about + // Use netdir/cs instead of netdir/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) - lines, err := queryCS("tcp", host, "1") + lines, err := queryCS("net", host, "1") if err != nil { return } +loop: for _, line := range lines { f := getFields(line) if len(f) < 2 { @@ -110,6 +136,12 @@ func lookupHost(host string) (addrs []string, err error) { if ParseIP(addr) == nil { continue } + // only return unique addresses + for _, a := range addrs { + if a == addr { + continue loop + } + } addrs = append(addrs, addr) } return @@ -167,7 +199,7 @@ func lookupCNAME(name string) (cname string, err error) { return f[2] + ".", nil } } - return "", errors.New("net: bad response from ndb/dns") + return "", errors.New("bad response from ndb/dns") } func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go index 59e9f6321..b1d2f8f31 100644 --- a/src/pkg/net/lookup_unix.go +++ b/src/pkg/net/lookup_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris package net diff --git a/src/pkg/net/mail/message.go b/src/pkg/net/mail/message.go index dc2ab44da..ba0778caa 100644 --- a/src/pkg/net/mail/message.go +++ b/src/pkg/net/mail/message.go @@ -159,7 +159,9 @@ func (a *Address) String() string { // If every character is printable ASCII, quoting is simple. allPrintable := true for i := 0; i < len(a.Name); i++ { - if !isVchar(a.Name[i]) { + // isWSP here should actually be isFWS, + // but we don't support folding yet. + if !isVchar(a.Name[i]) && !isWSP(a.Name[i]) { allPrintable = false break } @@ -167,7 +169,7 @@ func (a *Address) String() string { if allPrintable { b := bytes.NewBufferString(`"`) for i := 0; i < len(a.Name); i++ { - if !isQtext(a.Name[i]) { + if !isQtext(a.Name[i]) && !isWSP(a.Name[i]) { b.WriteByte('\\') } b.WriteByte(a.Name[i]) @@ -361,7 +363,7 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { // Ignore any error if we got at least one word. if err != nil && len(words) == 0 { debug.Printf("consumePhrase: hit err: %v", err) - return "", errors.New("mail: missing word in phrase") + return "", fmt.Errorf("mail: missing word in phrase: %v", err) } phrase = strings.Join(words, " ") return phrase, nil @@ -440,11 +442,11 @@ func (p *addrParser) len() int { func decodeRFC2047Word(s string) (string, error) { fields := strings.Split(s, "?") if len(fields) != 5 || fields[0] != "=" || fields[4] != "=" { - return "", errors.New("mail: address not RFC 2047 encoded") + return "", errors.New("address not RFC 2047 encoded") } charset, enc := strings.ToLower(fields[1]), strings.ToLower(fields[2]) if charset != "iso-8859-1" && charset != "utf-8" { - return "", fmt.Errorf("mail: charset not supported: %q", charset) + return "", fmt.Errorf("charset not supported: %q", charset) } in := bytes.NewBufferString(fields[3]) @@ -455,7 +457,7 @@ func decodeRFC2047Word(s string) (string, error) { case "q": r = qDecoder{r: in} default: - return "", fmt.Errorf("mail: RFC 2047 encoding not supported: %q", enc) + return "", fmt.Errorf("RFC 2047 encoding not supported: %q", enc) } dec, err := ioutil.ReadAll(r) @@ -535,3 +537,9 @@ func isVchar(c byte) bool { // Visible (printing) characters. return '!' <= c && c <= '~' } + +// isWSP returns true if c is a WSP (white space). +// WSP is a space or horizontal tab (RFC5234 Appendix B). +func isWSP(c byte) bool { + return c == ' ' || c == '\t' +} diff --git a/src/pkg/net/mail/message_test.go b/src/pkg/net/mail/message_test.go index 3c037f383..eb9c8cbdc 100644 --- a/src/pkg/net/mail/message_test.go +++ b/src/pkg/net/mail/message_test.go @@ -8,6 +8,7 @@ import ( "bytes" "io/ioutil" "reflect" + "strings" "testing" "time" ) @@ -116,6 +117,14 @@ func TestDateParsing(t *testing.T) { } } +func TestAddressParsingError(t *testing.T) { + const txt = "=?iso-8859-2?Q?Bogl=E1rka_Tak=E1cs?= <unknown@gmail.com>" + _, err := ParseAddress(txt) + if err == nil || !strings.Contains(err.Error(), "charset not supported") { + t.Errorf(`mail.ParseAddress(%q) err: %q, want ".*charset not supported.*"`, txt, err) + } +} + func TestAddressParsing(t *testing.T) { tests := []struct { addrsStr string @@ -277,6 +286,14 @@ func TestAddressFormatting(t *testing.T) { &Address{Name: "Böb", Address: "bob@example.com"}, `=?utf-8?q?B=C3=B6b?= <bob@example.com>`, }, + { + &Address{Name: "Bob Jane", Address: "bob@example.com"}, + `"Bob Jane" <bob@example.com>`, + }, + { + &Address{Name: "Böb Jacöb", Address: "bob@example.com"}, + `=?utf-8?q?B=C3=B6b_Jac=C3=B6b?= <bob@example.com>`, + }, } for _, test := range tests { s := test.addr.String() diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go index 5660fd42f..63dbce88e 100644 --- a/src/pkg/net/multicast_test.go +++ b/src/pkg/net/multicast_test.go @@ -25,8 +25,10 @@ var ipv4MulticastListenerTests = []struct { // port. func TestIPv4MulticastListener(t *testing.T) { switch runtime.GOOS { - case "plan9": + case "nacl", "plan9": t.Skipf("skipping test on %q", runtime.GOOS) + case "solaris": + t.Skipf("skipping test on solaris, see issue 7399") } closer := func(cs []*UDPConn) { @@ -93,8 +95,10 @@ var ipv6MulticastListenerTests = []struct { // port. func TestIPv6MulticastListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "solaris": + case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) + case "solaris": + t.Skipf("skipping test on solaris, see issue 7399") } if !supportsIPv6 { t.Skip("ipv6 is not supported") diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 2e6db5551..ca56af54f 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -275,7 +275,16 @@ type Listener interface { Addr() Addr } -var errMissingAddress = errors.New("missing address") +// Various errors contained in OpError. +var ( + // For connection setup and write operations. + errMissingAddress = errors.New("missing address") + + // For both read and write operations. + errTimeout error = &timeoutError{} + errClosing = errors.New("use of closed network connection") + ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection") +) // OpError is the error type usually returned by functions in the net // package. It describes the operation, network type, and address of @@ -337,10 +346,6 @@ func (e *timeoutError) Error() string { return "i/o timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true } -var errTimeout error = &timeoutError{} - -var errClosing = errors.New("use of closed network connection") - type AddrError struct { Err string Addr string diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index 1320096df..bfed4d657 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -28,12 +28,14 @@ func TestShutdown(t *testing.T) { defer ln.Close() c, err := ln.Accept() if err != nil { - t.Fatalf("Accept: %v", err) + t.Errorf("Accept: %v", err) + return } var buf [10]byte n, err := c.Read(buf[:]) if n != 0 || err != io.EOF { - t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err) + t.Errorf("server Read = %d, %v; want 0, io.EOF", n, err) + return } c.Write([]byte("response")) c.Close() @@ -62,7 +64,7 @@ func TestShutdown(t *testing.T) { func TestShutdownUnix(t *testing.T) { switch runtime.GOOS { - case "windows", "plan9": + case "nacl", "plan9", "windows": t.Skipf("skipping test on %q", runtime.GOOS) } f, err := ioutil.TempFile("", "go_net_unixtest") @@ -84,12 +86,14 @@ func TestShutdownUnix(t *testing.T) { go func() { c, err := ln.Accept() if err != nil { - t.Fatalf("Accept: %v", err) + t.Errorf("Accept: %v", err) + return } var buf [10]byte n, err := c.Read(buf[:]) if n != 0 || err != io.EOF { - t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err) + t.Errorf("server Read = %d, %v; want 0, io.EOF", n, err) + return } c.Write([]byte("response")) c.Close() @@ -196,7 +200,8 @@ func TestTCPClose(t *testing.T) { go func() { c, err := Dial("tcp", l.Addr().String()) if err != nil { - t.Fatal(err) + t.Errorf("Dial: %v", err) + return } go read(c) @@ -231,12 +236,12 @@ func TestErrorNil(t *testing.T) { // Make Listen fail by relistening on the same address. l, err := Listen("tcp", "127.0.0.1:0") if err != nil { - t.Fatal("Listen 127.0.0.1:0: %v", err) + t.Fatalf("Listen 127.0.0.1:0: %v", err) } defer l.Close() l1, err := Listen("tcp", l.Addr().String()) if err == nil { - t.Fatal("second Listen %v: %v", l.Addr(), err) + t.Fatalf("second Listen %v: %v", l.Addr(), err) } if l1 != nil { t.Fatalf("Listen returned non-nil interface %T(%v) with err != nil", l1, l1) @@ -245,12 +250,12 @@ func TestErrorNil(t *testing.T) { // Make ListenPacket fail by relistening on the same address. lp, err := ListenPacket("udp", "127.0.0.1:0") if err != nil { - t.Fatal("Listen 127.0.0.1:0: %v", err) + t.Fatalf("Listen 127.0.0.1:0: %v", err) } defer lp.Close() lp1, err := ListenPacket("udp", lp.LocalAddr().String()) if err == nil { - t.Fatal("second Listen %v: %v", lp.LocalAddr(), err) + t.Fatalf("second Listen %v: %v", lp.LocalAddr(), err) } if lp1 != nil { t.Fatalf("ListenPacket returned non-nil interface %T(%v) with err != nil", lp1, lp1) diff --git a/src/pkg/net/net_windows_test.go b/src/pkg/net/net_windows_test.go index 8b1c9cdc5..2f57745e3 100644 --- a/src/pkg/net/net_windows_test.go +++ b/src/pkg/net/net_windows_test.go @@ -84,7 +84,7 @@ func TestAcceptIgnoreSomeErrors(t *testing.T) { } err = cmd.Start() if err != nil { - t.Fatalf("cmd.Start failed: %v\n%s\n", err) + t.Fatalf("cmd.Start failed: %v\n", err) } outReader := bufio.NewReader(stdout) for { @@ -107,7 +107,7 @@ func TestAcceptIgnoreSomeErrors(t *testing.T) { result := make(chan error) go func() { time.Sleep(alittle) - err = send(ln.Addr().String(), "abc") + err := send(ln.Addr().String(), "abc") if err != nil { result <- err } diff --git a/src/pkg/net/netgo_unix_test.go b/src/pkg/net/netgo_unix_test.go new file mode 100644 index 000000000..9fb2a567d --- /dev/null +++ b/src/pkg/net/netgo_unix_test.go @@ -0,0 +1,24 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !cgo netgo +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import "testing" + +func TestGoLookupIP(t *testing.T) { + host := "localhost" + _, err, ok := cgoLookupIP(host) + if ok { + t.Errorf("cgoLookupIP must be a placeholder") + } + if err != nil { + t.Errorf("cgoLookupIP failed: %v", err) + } + if _, err := goLookupIP(host); err != nil { + t.Errorf("goLookupIP failed: %v", err) + } +} diff --git a/src/pkg/net/packetconn_test.go b/src/pkg/net/packetconn_test.go index 945003f67..b6e4e76f9 100644 --- a/src/pkg/net/packetconn_test.go +++ b/src/pkg/net/packetconn_test.go @@ -15,12 +15,6 @@ import ( "time" ) -func strfunc(s string) func() string { - return func() string { - return s - } -} - func packetConnTestData(t *testing.T, net string, i int) ([]byte, func()) { switch net { case "udp": @@ -46,7 +40,7 @@ func packetConnTestData(t *testing.T, net string, i int) ([]byte, func()) { return b, nil case "unixgram": switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": return nil, func() { t.Logf("skipping %q test on %q", net, runtime.GOOS) } @@ -62,12 +56,12 @@ func packetConnTestData(t *testing.T, net string, i int) ([]byte, func()) { var packetConnTests = []struct { net string - addr1 func() string - addr2 func() string + addr1 string + addr2 string }{ - {"udp", strfunc("127.0.0.1:0"), strfunc("127.0.0.1:0")}, - {"ip:icmp", strfunc("127.0.0.1"), strfunc("127.0.0.1")}, - {"unixgram", testUnixAddr, testUnixAddr}, + {"udp", "127.0.0.1:0", "127.0.0.1:0"}, + {"ip:icmp", "127.0.0.1", "127.0.0.1"}, + {"unixgram", testUnixAddr(), testUnixAddr()}, } func TestPacketConn(t *testing.T) { @@ -88,22 +82,21 @@ func TestPacketConn(t *testing.T) { continue } - addr1, addr2 := tt.addr1(), tt.addr2() - c1, err := ListenPacket(tt.net, addr1) + c1, err := ListenPacket(tt.net, tt.addr1) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c1, netstr[0], addr1, addr2) + defer closer(c1, netstr[0], tt.addr1, tt.addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - c2, err := ListenPacket(tt.net, addr2) + c2, err := ListenPacket(tt.net, tt.addr2) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c2, netstr[0], addr1, addr2) + defer closer(c2, netstr[0], tt.addr1, tt.addr2) c2.LocalAddr() c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) @@ -145,12 +138,11 @@ func TestConnAndPacketConn(t *testing.T) { continue } - addr1, addr2 := tt.addr1(), tt.addr2() - c1, err := ListenPacket(tt.net, addr1) + c1, err := ListenPacket(tt.net, tt.addr1) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c1, netstr[0], addr1, addr2) + defer closer(c1, netstr[0], tt.addr1, tt.addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) diff --git a/src/pkg/net/parse.go b/src/pkg/net/parse.go index 6056de248..ee6e7e995 100644 --- a/src/pkg/net/parse.go +++ b/src/pkg/net/parse.go @@ -67,7 +67,7 @@ func open(name string) (*file, error) { if err != nil { return nil, err } - return &file{fd, make([]byte, os.Getpagesize())[0:0], false}, nil + return &file{fd, make([]byte, 0, os.Getpagesize()), false}, nil } func byteIndex(s string, c byte) int { diff --git a/src/pkg/net/port_unix.go b/src/pkg/net/port_unix.go index 3cd9ca2aa..89558c1f0 100644 --- a/src/pkg/net/port_unix.go +++ b/src/pkg/net/port_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris // Read system port mappings from /etc/services @@ -10,12 +10,16 @@ package net import "sync" -var services map[string]map[string]int +// services contains minimal mappings between services names and port +// numbers for platforms that don't have a complete list of port numbers +// (some Solaris distros). +var services = map[string]map[string]int{ + "tcp": {"http": 80}, +} var servicesError error var onceReadServices sync.Once func readServices() { - services = make(map[string]map[string]int) var file *file if file, servicesError = open("/etc/services"); servicesError != nil { return @@ -29,7 +33,7 @@ func readServices() { if len(f) < 2 { continue } - portnet := f[1] // "tcp/80" + portnet := f[1] // "80/tcp" port, j, ok := dtoi(portnet, 0) if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { continue diff --git a/src/pkg/net/protoconn_test.go b/src/pkg/net/protoconn_test.go index 5a8958b08..12856b6c3 100644 --- a/src/pkg/net/protoconn_test.go +++ b/src/pkg/net/protoconn_test.go @@ -19,7 +19,7 @@ import ( // also uses /tmp directory in case it is prohibited to create UNIX // sockets in TMPDIR. func testUnixAddr() string { - f, err := ioutil.TempFile("/tmp", "nettest") + f, err := ioutil.TempFile("", "nettest") if err != nil { panic(err) } @@ -236,7 +236,7 @@ func TestIPConnSpecificMethods(t *testing.T) { func TestUnixListenerSpecificMethods(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": t.Skipf("skipping test on %q", runtime.GOOS) } @@ -278,7 +278,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) { func TestUnixConnSpecificMethods(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "nacl", "plan9", "windows": t.Skipf("skipping test on %q", runtime.GOOS) } diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go index c524d0a0a..21f79b068 100644 --- a/src/pkg/net/rpc/client.go +++ b/src/pkg/net/rpc/client.go @@ -39,14 +39,16 @@ type Call struct { // with a single Client, and a Client may be used by // multiple goroutines simultaneously. type Client struct { - mutex sync.Mutex // protects pending, seq, request - sending sync.Mutex + codec ClientCodec + + sending sync.Mutex + + mutex sync.Mutex // protects following request Request seq uint64 - codec ClientCodec pending map[uint64]*Call - closing bool - shutdown bool + closing bool // user has called Close + shutdown bool // server has told us to stop } // A ClientCodec implements writing of RPC requests and @@ -274,7 +276,7 @@ func Dial(network, address string) (*Client, error) { func (client *Client) Close() error { client.mutex.Lock() - if client.shutdown || client.closing { + if client.closing { client.mutex.Unlock() return ErrShutdown } diff --git a/src/pkg/net/rpc/client_test.go b/src/pkg/net/rpc/client_test.go new file mode 100644 index 000000000..bbfc1ec3a --- /dev/null +++ b/src/pkg/net/rpc/client_test.go @@ -0,0 +1,36 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "errors" + "testing" +) + +type shutdownCodec struct { + responded chan int + closed bool +} + +func (c *shutdownCodec) WriteRequest(*Request, interface{}) error { return nil } +func (c *shutdownCodec) ReadResponseBody(interface{}) error { return nil } +func (c *shutdownCodec) ReadResponseHeader(*Response) error { + c.responded <- 1 + return errors.New("shutdownCodec ReadResponseHeader") +} +func (c *shutdownCodec) Close() error { + c.closed = true + return nil +} + +func TestCloseCodec(t *testing.T) { + codec := &shutdownCodec{responded: make(chan int)} + client := NewClientWithCodec(codec) + <-codec.responded + client.Close() + if !codec.closed { + t.Error("client.Close did not close codec") + } +} diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go index 40d4b82d7..a433a365e 100644 --- a/src/pkg/net/rpc/jsonrpc/all_test.go +++ b/src/pkg/net/rpc/jsonrpc/all_test.go @@ -5,6 +5,7 @@ package jsonrpc import ( + "bytes" "encoding/json" "errors" "fmt" @@ -12,6 +13,7 @@ import ( "io/ioutil" "net" "net/rpc" + "strings" "testing" ) @@ -202,6 +204,39 @@ func TestMalformedOutput(t *testing.T) { } } +func TestServerErrorHasNullResult(t *testing.T) { + var out bytes.Buffer + sc := NewServerCodec(struct { + io.Reader + io.Writer + io.Closer + }{ + Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`), + Writer: &out, + Closer: ioutil.NopCloser(nil), + }) + r := new(rpc.Request) + if err := sc.ReadRequestHeader(r); err != nil { + t.Fatal(err) + } + const valueText = "the value we don't want to see" + const errorText = "some error" + err := sc.WriteResponse(&rpc.Response{ + ServiceMethod: "Method", + Seq: 1, + Error: errorText, + }, valueText) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), errorText) { + t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out) + } + if strings.Contains(out.String(), valueText) { + t.Errorf("Response contains both an error and value: %s", &out) + } +} + func TestUnexpectedError(t *testing.T) { cli, srv := myPipe() go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go index 16ec0fe9a..e6d37cfa6 100644 --- a/src/pkg/net/rpc/jsonrpc/server.go +++ b/src/pkg/net/rpc/jsonrpc/server.go @@ -100,7 +100,6 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error { var null = json.RawMessage([]byte("null")) func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { - var resp serverResponse c.mutex.Lock() b, ok := c.pending[r.Seq] if !ok { @@ -114,10 +113,9 @@ func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { // Invalid request so no id. Use JSON null. b = &null } - resp.Id = b - resp.Result = x + resp := serverResponse{Id: b} if r.Error == "" { - resp.Error = nil + resp.Result = x } else { resp.Error = r.Error } diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go index 7eb2dcf5a..6b264b46b 100644 --- a/src/pkg/net/rpc/server.go +++ b/src/pkg/net/rpc/server.go @@ -217,10 +217,11 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method -// - two arguments, both pointers to exported structs +// - two arguments, both of exported type +// - the second argument is a pointer // - one return value, of type error // It returns an error if the receiver is not an exported type or has -// no methods or unsuitable methods. It also logs the error using package log. +// no suitable methods. It also logs the error using package log. // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. func (server *Server) Register(rcvr interface{}) error { diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go index 3b9a88380..0dc4ddc2d 100644 --- a/src/pkg/net/rpc/server_test.go +++ b/src/pkg/net/rpc/server_test.go @@ -594,7 +594,6 @@ func TestErrorAfterClientClose(t *testing.T) { } func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { - b.StopTimer() once.Do(startServer) client, err := dial() if err != nil { @@ -604,33 +603,24 @@ func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { // Synchronous calls args := &Args{7, 8} - procs := runtime.GOMAXPROCS(-1) - N := int32(b.N) - var wg sync.WaitGroup - wg.Add(procs) - b.StartTimer() - - for p := 0; p < procs; p++ { - go func() { - reply := new(Reply) - for atomic.AddInt32(&N, -1) >= 0 { - err := client.Call("Arith.Add", args, reply) - if err != nil { - b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) - } - if reply.C != args.A+args.B { - b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) - } + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + reply := new(Reply) + for pb.Next() { + err := client.Call("Arith.Add", args, reply) + if err != nil { + b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) } - wg.Done() - }() - } - wg.Wait() + if reply.C != args.A+args.B { + b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) + } + } + }) } func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { const MaxConcurrentCalls = 100 - b.StopTimer() once.Do(startServer) client, err := dial() if err != nil { @@ -647,7 +637,7 @@ func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { wg.Add(procs) gate := make(chan bool, MaxConcurrentCalls) res := make(chan *Call, MaxConcurrentCalls) - b.StartTimer() + b.ResetTimer() for p := 0; p < procs; p++ { go func() { diff --git a/src/pkg/net/sendfile_dragonfly.go b/src/pkg/net/sendfile_dragonfly.go index a2219c163..bc88fd3b9 100644 --- a/src/pkg/net/sendfile_dragonfly.go +++ b/src/pkg/net/sendfile_dragonfly.go @@ -23,7 +23,7 @@ const maxSendfileSize int = 4 << 20 // if handled == false, sendFile performed no work. func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { // DragonFly uses 0 as the "until EOF" value. If you pass in more bytes than the - // file contains, it will loop back to the beginning ad nauseum until it's sent + // file contains, it will loop back to the beginning ad nauseam until it's sent // exactly the number of bytes told to. As such, we need to know exactly how many // bytes to send. var remain int64 = 0 diff --git a/src/pkg/net/sendfile_freebsd.go b/src/pkg/net/sendfile_freebsd.go index 42fe799ef..ffc147262 100644 --- a/src/pkg/net/sendfile_freebsd.go +++ b/src/pkg/net/sendfile_freebsd.go @@ -23,7 +23,7 @@ const maxSendfileSize int = 4 << 20 // if handled == false, sendFile performed no work. func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { // FreeBSD uses 0 as the "until EOF" value. If you pass in more bytes than the - // file contains, it will loop back to the beginning ad nauseum until it's sent + // file contains, it will loop back to the beginning ad nauseam until it's sent // exactly the number of bytes told to. As such, we need to know exactly how many // bytes to send. var remain int64 = 0 diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go index 3660849c1..03426ef0d 100644 --- a/src/pkg/net/sendfile_stub.go +++ b/src/pkg/net/sendfile_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin netbsd openbsd +// +build darwin nacl netbsd openbsd solaris package net diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 9194a8ec2..6a2bb9243 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -9,21 +9,20 @@ import ( "io" "os" "runtime" - "strconv" "testing" "time" ) -func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxonly bool) bool { +func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxOnly bool) bool { switch runtime.GOOS { case "linux": - case "plan9", "windows": + case "nacl", "plan9", "windows": // "unix" sockets are not supported on Windows and Plan 9. if net == unixsotype { return true } default: - if net == unixsotype && linuxonly { + if net == unixsotype && linuxOnly { return true } } @@ -42,21 +41,15 @@ func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxonly bool) return false } -func tempfile(filename string) string { - // use /tmp in case it is prohibited to create - // UNIX sockets in TMPDIR - return "/tmp/" + filename + "." + strconv.Itoa(os.Getpid()) -} - var streamConnServerTests = []struct { - snet string // server side - saddr string - cnet string // client side - caddr string - ipv6 bool // test with underlying AF_INET6 socket - ipv4map bool // test with IPv6 IPv4-mapping functionality - empty bool // test with empty data - linux bool // test with abstract unix domain socket, a Linux-ism + snet string // server side + saddr string + cnet string // client side + caddr string + ipv6 bool // test with underlying AF_INET6 socket + ipv4map bool // test with IPv6 IPv4-mapping functionality + empty bool // test with empty data + linuxOnly bool // test with abstract unix domain socket, a Linux-ism }{ {snet: "tcp", saddr: "", cnet: "tcp", caddr: "127.0.0.1"}, {snet: "tcp", saddr: "0.0.0.0", cnet: "tcp", caddr: "127.0.0.1"}, @@ -93,13 +86,13 @@ var streamConnServerTests = []struct { {snet: "tcp6", saddr: "[::1]", cnet: "tcp6", caddr: "[::1]", ipv6: true}, - {snet: "unix", saddr: tempfile("gotest1.net"), cnet: "unix", caddr: tempfile("gotest1.net.local")}, - {snet: "unix", saddr: "@gotest2/net", cnet: "unix", caddr: "@gotest2/net.local", linux: true}, + {snet: "unix", saddr: testUnixAddr(), cnet: "unix", caddr: testUnixAddr()}, + {snet: "unix", saddr: "@gotest2/net", cnet: "unix", caddr: "@gotest2/net.local", linuxOnly: true}, } func TestStreamConnServer(t *testing.T) { for _, tt := range streamConnServerTests { - if skipServerTest(tt.snet, "unix", tt.saddr, tt.ipv6, tt.ipv4map, tt.linux) { + if skipServerTest(tt.snet, "unix", tt.saddr, tt.ipv6, tt.ipv4map, tt.linuxOnly) { continue } @@ -137,21 +130,28 @@ func TestStreamConnServer(t *testing.T) { } var seqpacketConnServerTests = []struct { - net string - saddr string // server address - caddr string // client address - empty bool // test with empty data + net string + saddr string // server address + caddr string // client address + empty bool // test with empty data + linuxOnly bool // test with abstract unix domain socket, a Linux-ism }{ - {net: "unixpacket", saddr: tempfile("/gotest3.net"), caddr: tempfile("gotest3.net.local")}, - {net: "unixpacket", saddr: "@gotest4/net", caddr: "@gotest4/net.local"}, + {net: "unixpacket", saddr: testUnixAddr(), caddr: testUnixAddr()}, + {net: "unixpacket", saddr: "@gotest4/net", caddr: "@gotest4/net.local", linuxOnly: true}, } func TestSeqpacketConnServer(t *testing.T) { - if runtime.GOOS != "linux" { + switch runtime.GOOS { + case "darwin", "nacl", "openbsd", "plan9", "windows": + fallthrough + case "freebsd": // FreeBSD 8 doesn't support unixpacket t.Skipf("skipping test on %q", runtime.GOOS) } for _, tt := range seqpacketConnServerTests { + if runtime.GOOS != "linux" && tt.linuxOnly { + continue + } listening := make(chan string) done := make(chan int) switch tt.net { @@ -248,15 +248,15 @@ func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) { var testDatagram = flag.Bool("datagram", false, "whether to test udp and unixgram") var datagramPacketConnServerTests = []struct { - snet string // server side - saddr string - cnet string // client side - caddr string - ipv6 bool // test with underlying AF_INET6 socket - ipv4map bool // test with IPv6 IPv4-mapping functionality - dial bool // test with Dial or DialUnix - empty bool // test with empty data - linux bool // test with abstract unix domain socket, a Linux-ism + snet string // server side + saddr string + cnet string // client side + caddr string + ipv6 bool // test with underlying AF_INET6 socket + ipv4map bool // test with IPv6 IPv4-mapping functionality + dial bool // test with Dial or DialUnix + empty bool // test with empty data + linuxOnly bool // test with abstract unix domain socket, a Linux-ism }{ {snet: "udp", saddr: "", cnet: "udp", caddr: "127.0.0.1"}, {snet: "udp", saddr: "0.0.0.0", cnet: "udp", caddr: "127.0.0.1"}, @@ -301,12 +301,12 @@ var datagramPacketConnServerTests = []struct { {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, empty: true}, {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, dial: true, empty: true}, - {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local")}, - {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true}, - {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), empty: true}, - {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true, empty: true}, + {snet: "unixgram", saddr: testUnixAddr(), cnet: "unixgram", caddr: testUnixAddr()}, + {snet: "unixgram", saddr: testUnixAddr(), cnet: "unixgram", caddr: testUnixAddr(), dial: true}, + {snet: "unixgram", saddr: testUnixAddr(), cnet: "unixgram", caddr: testUnixAddr(), empty: true}, + {snet: "unixgram", saddr: testUnixAddr(), cnet: "unixgram", caddr: testUnixAddr(), dial: true, empty: true}, - {snet: "unixgram", saddr: "@gotest6/net", cnet: "unixgram", caddr: "@gotest6/net.local", linux: true}, + {snet: "unixgram", saddr: "@gotest6/net", cnet: "unixgram", caddr: "@gotest6/net.local", linuxOnly: true}, } func TestDatagramPacketConnServer(t *testing.T) { @@ -315,7 +315,7 @@ func TestDatagramPacketConnServer(t *testing.T) { } for _, tt := range datagramPacketConnServerTests { - if skipServerTest(tt.snet, "unixgram", tt.saddr, tt.ipv6, tt.ipv4map, tt.linux) { + if skipServerTest(tt.snet, "unixgram", tt.saddr, tt.ipv6, tt.ipv4map, tt.linuxOnly) { continue } diff --git a/src/pkg/net/smtp/example_test.go b/src/pkg/net/smtp/example_test.go new file mode 100644 index 000000000..d551e365a --- /dev/null +++ b/src/pkg/net/smtp/example_test.go @@ -0,0 +1,61 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package smtp_test + +import ( + "fmt" + "log" + "net/smtp" +) + +func Example() { + // Connect to the remote SMTP server. + c, err := smtp.Dial("mail.example.com:25") + if err != nil { + log.Fatal(err) + } + + // Set the sender and recipient first + if err := c.Mail("sender@example.org"); err != nil { + log.Fatal(err) + } + if err := c.Rcpt("recipient@example.net"); err != nil { + log.Fatal(err) + } + + // Send the email body. + wc, err := c.Data() + if err != nil { + log.Fatal(err) + } + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + log.Fatal(err) + } + err = wc.Close() + if err != nil { + log.Fatal(err) + } + + // Send the QUIT command and close the connection. + err = c.Quit() + if err != nil { + log.Fatal(err) + } +} + +func ExamplePlainAuth() { + // Set up authentication information. + auth := smtp.PlainAuth("", "user@example.com", "password", "mail.example.com") + + // Connect to the server, authenticate, set the sender and recipient, + // and send the email all in one step. + to := []string{"recipient@example.net"} + msg := []byte("This is the email body.") + err := smtp.SendMail("mail.example.com:25", auth, "sender@example.org", to, msg) + if err != nil { + log.Fatal(err) + } +} diff --git a/src/pkg/net/smtp/smtp.go b/src/pkg/net/smtp/smtp.go index a0a478a85..87dea442c 100644 --- a/src/pkg/net/smtp/smtp.go +++ b/src/pkg/net/smtp/smtp.go @@ -264,6 +264,8 @@ func (c *Client) Data() (io.WriteCloser, error) { return &dataCloser{c, c.Text.DotWriter()}, nil } +var testHookStartTLS func(*tls.Config) // nil, except for tests + // SendMail connects to the server at addr, switches to TLS if // possible, authenticates with the optional mechanism a if possible, // and then sends an email from address from, to addresses to, with @@ -278,7 +280,11 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { return err } if ok, _ := c.Extension("STARTTLS"); ok { - if err = c.StartTLS(nil); err != nil { + config := &tls.Config{ServerName: c.serverName} + if testHookStartTLS != nil { + testHookStartTLS(config) + } + if err = c.StartTLS(config); err != nil { return err } } diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go index 2133dc7c7..3fba1ea5a 100644 --- a/src/pkg/net/smtp/smtp_test.go +++ b/src/pkg/net/smtp/smtp_test.go @@ -7,6 +7,8 @@ package smtp import ( "bufio" "bytes" + "crypto/tls" + "crypto/x509" "io" "net" "net/textproto" @@ -548,3 +550,145 @@ AUTH PLAIN AHVzZXIAcGFzcw== * QUIT ` + +func TestTLSClient(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + errc := make(chan error) + go func() { + errc <- sendMail(ln.Addr().String()) + }() + conn, err := ln.Accept() + if err != nil { + t.Fatalf("failed to accept connection: %v", err) + } + defer conn.Close() + if err := serverHandle(conn, t); err != nil { + t.Fatalf("failed to handle connection: %v", err) + } + if err := <-errc; err != nil { + t.Fatalf("client error: %v", err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +type smtpSender struct { + w io.Writer +} + +func (s smtpSender) send(f string) { + s.w.Write([]byte(f + "\r\n")) +} + +// smtp server, finely tailored to deal with our own client only! +func serverHandle(c net.Conn, t *testing.T) error { + send := smtpSender{c}.send + send("220 127.0.0.1 ESMTP service ready") + s := bufio.NewScanner(c) + for s.Scan() { + switch s.Text() { + case "EHLO localhost": + send("250-127.0.0.1 ESMTP offers a warm hug of welcome") + send("250-STARTTLS") + send("250 Ok") + case "STARTTLS": + send("220 Go ahead") + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + return err + } + config := &tls.Config{Certificates: []tls.Certificate{keypair}} + c = tls.Server(c, config) + defer c.Close() + return serverHandleTLS(c, t) + default: + t.Fatalf("unrecognized command: %q", s.Text()) + } + } + return s.Err() +} + +func serverHandleTLS(c net.Conn, t *testing.T) error { + send := smtpSender{c}.send + s := bufio.NewScanner(c) + for s.Scan() { + switch s.Text() { + case "EHLO localhost": + send("250 Ok") + case "MAIL FROM:<joe1@example.com>": + send("250 Ok") + case "RCPT TO:<joe2@example.com>": + send("250 Ok") + case "DATA": + send("354 send the mail data, end with .") + send("250 Ok") + case "Subject: test": + case "": + case "howdy!": + case ".": + case "QUIT": + send("221 127.0.0.1 Service closing transmission channel") + return nil + default: + t.Fatalf("unrecognized command during TLS: %q", s.Text()) + } + } + return s.Err() +} + +func init() { + testRootCAs := x509.NewCertPool() + testRootCAs.AppendCertsFromPEM(localhostCert) + testHookStartTLS = func(config *tls.Config) { + config.RootCAs = testRootCAs + } +} + +func sendMail(hostPort string) error { + host, _, err := net.SplitHostPort(hostPort) + if err != nil { + return err + } + auth := PlainAuth("", "", "", host) + from := "joe1@example.com" + to := []string{"joe2@example.com"} + return SendMail(hostPort, auth, from, to, []byte("Subject: test\n\nhowdy!")) +} + +// (copied from net/http/httptest) +// localhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end +// of ASN.1 time). +// generated from src/pkg/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go index 6c37109f5..48fb78527 100644 --- a/src/pkg/net/sock_bsd.go +++ b/src/pkg/net/sock_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd netbsd openbsd +// +build darwin dragonfly freebsd nacl netbsd openbsd package net diff --git a/src/pkg/net/sock_cloexec.go b/src/pkg/net/sock_cloexec.go index 3f22cd8f5..dec81855b 100644 --- a/src/pkg/net/sock_cloexec.go +++ b/src/pkg/net/sock_cloexec.go @@ -5,7 +5,7 @@ // This file implements sysSocket and accept for platforms that // provide a fast path for setting SetNonblock and CloseOnExec. -// +build linux +// +build freebsd linux package net @@ -13,18 +13,20 @@ import "syscall" // Wrapper around the socket system call that marks the returned file // descriptor as nonblocking and close-on-exec. -func sysSocket(f, t, p int) (int, error) { - s, err := syscall.Socket(f, t|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, p) - // The SOCK_NONBLOCK and SOCK_CLOEXEC flags were introduced in - // Linux 2.6.27. If we get an EINVAL error, fall back to - // using socket without them. - if err == nil || err != syscall.EINVAL { +func sysSocket(family, sotype, proto int) (int, error) { + s, err := syscall.Socket(family, sotype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, proto) + // On Linux the SOCK_NONBLOCK and SOCK_CLOEXEC flags were + // introduced in 2.6.27 kernel and on FreeBSD both flags were + // introduced in 10 kernel. If we get an EINVAL error on Linux + // or EPROTONOSUPPORT error on FreeBSD, fall back to using + // socket without them. + if err == nil || (err != syscall.EPROTONOSUPPORT && err != syscall.EINVAL) { return s, err } // See ../syscall/exec_unix.go for description of ForkLock. syscall.ForkLock.RLock() - s, err = syscall.Socket(f, t, p) + s, err = syscall.Socket(family, sotype, proto) if err == nil { syscall.CloseOnExec(s) } @@ -41,12 +43,19 @@ func sysSocket(f, t, p int) (int, error) { // Wrapper around the accept system call that marks the returned file // descriptor as nonblocking and close-on-exec. -func accept(fd int) (int, syscall.Sockaddr, error) { - nfd, sa, err := syscall.Accept4(fd, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) - // The accept4 system call was introduced in Linux 2.6.28. If - // we get an ENOSYS or EINVAL error, fall back to using accept. - if err == nil || (err != syscall.ENOSYS && err != syscall.EINVAL) { - return nfd, sa, err +func accept(s int) (int, syscall.Sockaddr, error) { + ns, sa, err := syscall.Accept4(s, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) + // On Linux the accept4 system call was introduced in 2.6.28 + // kernel and on FreeBSD it was introduced in 10 kernel. If we + // get an ENOSYS error on both Linux and FreeBSD, or EINVAL + // error on Linux, fall back to using accept. + switch err { + default: // nil and errors other than the ones listed + return ns, sa, err + case syscall.ENOSYS: // syscall missing + case syscall.EINVAL: // some Linux use this instead of ENOSYS + case syscall.EACCES: // some Linux use this instead of ENOSYS + case syscall.EFAULT: // some Linux use this instead of ENOSYS } // See ../syscall/exec_unix.go for description of ForkLock. @@ -54,16 +63,16 @@ func accept(fd int) (int, syscall.Sockaddr, error) { // because we have put fd.sysfd into non-blocking mode. // However, a call to the File method will put it back into // blocking mode. We can't take that risk, so no use of ForkLock here. - nfd, sa, err = syscall.Accept(fd) + ns, sa, err = syscall.Accept(s) if err == nil { - syscall.CloseOnExec(nfd) + syscall.CloseOnExec(ns) } if err != nil { return -1, nil, err } - if err = syscall.SetNonblock(nfd, true); err != nil { - syscall.Close(nfd) + if err = syscall.SetNonblock(ns, true); err != nil { + syscall.Close(ns) return -1, nil, err } - return nfd, sa, nil + return ns, sa, nil } diff --git a/src/pkg/net/sock_posix.go b/src/pkg/net/sock_posix.go index c2d343c58..a6ef874c9 100644 --- a/src/pkg/net/sock_posix.go +++ b/src/pkg/net/sock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net diff --git a/src/pkg/net/sock_solaris.go b/src/pkg/net/sock_solaris.go new file mode 100644 index 000000000..90fe9de89 --- /dev/null +++ b/src/pkg/net/sock_solaris.go @@ -0,0 +1,13 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "syscall" + +func maxListenerBacklog() int { + // TODO: Implement this + // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030. + return syscall.SOMAXCONN +} diff --git a/src/pkg/net/sockopt_bsd.go b/src/pkg/net/sockopt_bsd.go index ef6eb8505..77d51d737 100644 --- a/src/pkg/net/sockopt_bsd.go +++ b/src/pkg/net/sockopt_bsd.go @@ -2,16 +2,29 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd netbsd openbsd +// +build darwin dragonfly freebsd nacl netbsd openbsd package net import ( "os" + "runtime" "syscall" ) func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { + if runtime.GOOS == "dragonfly" && sotype != syscall.SOCK_RAW { + // On DragonFly BSD, we adjust the ephemeral port + // range because unlike other BSD systems its default + // port range doesn't conform to IANA recommendation + // as described in RFC 6355 and is pretty narrow. + switch family { + case syscall.AF_INET: + syscall.SetsockoptInt(s, syscall.IPPROTO_IP, syscall.IP_PORTRANGE, syscall.IP_PORTRANGE_HIGH) + case syscall.AF_INET6: + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_PORTRANGE, syscall.IPV6_PORTRANGE_HIGH) + } + } if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { // Allow both IP versions even if the OS default // is otherwise. Note that some operating systems diff --git a/src/pkg/net/sockopt_plan9.go b/src/pkg/net/sockopt_plan9.go new file mode 100644 index 000000000..8bc689b6c --- /dev/null +++ b/src/pkg/net/sockopt_plan9.go @@ -0,0 +1,13 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +func setKeepAlive(fd *netFD, keepalive bool) error { + if keepalive { + _, e := fd.ctl.WriteAt([]byte("keepalive"), 0) + return e + } + return nil +} diff --git a/src/pkg/net/sockopt_posix.go b/src/pkg/net/sockopt_posix.go index ff3bc6899..921918c37 100644 --- a/src/pkg/net/sockopt_posix.go +++ b/src/pkg/net/sockopt_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net diff --git a/src/pkg/net/sockopt_solaris.go b/src/pkg/net/sockopt_solaris.go new file mode 100644 index 000000000..54c20b140 --- /dev/null +++ b/src/pkg/net/sockopt_solaris.go @@ -0,0 +1,32 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { + if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { + // Allow both IP versions even if the OS default + // is otherwise. Note that some operating systems + // never admit this option. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) + } + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) +} + +func setDefaultListenerSockopts(s int) error { + // Allow reuse of recently-used addresses. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) +} + +func setDefaultMulticastSockopts(s int) error { + // Allow multicast UDP and raw IP datagram sockets to listen + // concurrently across multiple listeners. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) +} diff --git a/src/pkg/net/sockoptip_bsd.go b/src/pkg/net/sockoptip_bsd.go index 2199e480d..87132f0f4 100644 --- a/src/pkg/net/sockoptip_bsd.go +++ b/src/pkg/net/sockoptip_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd netbsd openbsd +// +build darwin dragonfly freebsd nacl netbsd openbsd package net diff --git a/src/pkg/net/sockoptip_posix.go b/src/pkg/net/sockoptip_posix.go index c2579be91..b5c80e449 100644 --- a/src/pkg/net/sockoptip_posix.go +++ b/src/pkg/net/sockoptip_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd windows package net diff --git a/src/pkg/net/sockoptip_stub.go b/src/pkg/net/sockoptip_stub.go new file mode 100644 index 000000000..dcd3a22b5 --- /dev/null +++ b/src/pkg/net/sockoptip_stub.go @@ -0,0 +1,39 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build solaris + +package net + +import "syscall" + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} + +func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} + +func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} + +func setIPv6MulticastLoopback(fd *netFD, v bool) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} + +func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { + // See golang.org/issue/7399. + return syscall.EINVAL +} diff --git a/src/pkg/net/sys_cloexec.go b/src/pkg/net/sys_cloexec.go index bbfcc1a4f..898fb7c0c 100644 --- a/src/pkg/net/sys_cloexec.go +++ b/src/pkg/net/sys_cloexec.go @@ -5,7 +5,7 @@ // This file implements sysSocket and accept for platforms that do not // provide a fast path for setting SetNonblock and CloseOnExec. -// +build darwin dragonfly freebsd netbsd openbsd +// +build darwin dragonfly nacl netbsd openbsd solaris package net @@ -13,10 +13,10 @@ import "syscall" // Wrapper around the socket system call that marks the returned file // descriptor as nonblocking and close-on-exec. -func sysSocket(f, t, p int) (int, error) { +func sysSocket(family, sotype, proto int) (int, error) { // See ../syscall/exec_unix.go for description of ForkLock. syscall.ForkLock.RLock() - s, err := syscall.Socket(f, t, p) + s, err := syscall.Socket(family, sotype, proto) if err == nil { syscall.CloseOnExec(s) } @@ -33,22 +33,22 @@ func sysSocket(f, t, p int) (int, error) { // Wrapper around the accept system call that marks the returned file // descriptor as nonblocking and close-on-exec. -func accept(fd int) (int, syscall.Sockaddr, error) { +func accept(s int) (int, syscall.Sockaddr, error) { // See ../syscall/exec_unix.go for description of ForkLock. // It is probably okay to hold the lock across syscall.Accept // because we have put fd.sysfd into non-blocking mode. // However, a call to the File method will put it back into // blocking mode. We can't take that risk, so no use of ForkLock here. - nfd, sa, err := syscall.Accept(fd) + ns, sa, err := syscall.Accept(s) if err == nil { - syscall.CloseOnExec(nfd) + syscall.CloseOnExec(ns) } if err != nil { return -1, nil, err } - if err = syscall.SetNonblock(nfd, true); err != nil { - syscall.Close(nfd) + if err = syscall.SetNonblock(ns, true); err != nil { + syscall.Close(ns) return -1, nil, err } - return nfd, sa, nil + return ns, sa, nil } diff --git a/src/pkg/net/tcp_test.go b/src/pkg/net/tcp_test.go index 62fd99f5c..c04198ea0 100644 --- a/src/pkg/net/tcp_test.go +++ b/src/pkg/net/tcp_test.go @@ -97,6 +97,7 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { b.Fatalf("Listen failed: %v", err) } defer ln.Close() + serverSem := make(chan bool, numConcurrent) // Acceptor. go func() { for { @@ -104,9 +105,13 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { if err != nil { break } + serverSem <- true // Server connection. go func(c Conn) { - defer c.Close() + defer func() { + c.Close() + <-serverSem + }() if timeout { c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire. } @@ -119,13 +124,13 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { }(c) } }() - sem := make(chan bool, numConcurrent) + clientSem := make(chan bool, numConcurrent) for i := 0; i < conns; i++ { - sem <- true + clientSem <- true // Client connection. go func() { defer func() { - <-sem + <-clientSem }() c, err := Dial("tcp", ln.Addr().String()) if err != nil { @@ -144,8 +149,9 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { } }() } - for i := 0; i < cap(sem); i++ { - sem <- true + for i := 0; i < numConcurrent; i++ { + clientSem <- true + serverSem <- true } } @@ -185,7 +191,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { for p := 0; p < P; p++ { s, err := ln.Accept() if err != nil { - b.Fatalf("Accept failed: %v", err) + b.Errorf("Accept failed: %v", err) + return } servers[p] = s } @@ -217,7 +224,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { buf[0] = v _, err := c.Write(buf[:]) if err != nil { - b.Fatalf("Write failed: %v", err) + b.Errorf("Write failed: %v", err) + return } } }(clients[p]) @@ -232,7 +240,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { for i := 0; i < N; i++ { _, err := s.Read(buf[:]) if err != nil { - b.Fatalf("Read failed: %v", err) + b.Errorf("Read failed: %v", err) + return } pipe <- buf[0] } @@ -250,7 +259,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { buf[0] = v _, err := s.Write(buf[:]) if err != nil { - b.Fatalf("Write failed: %v", err) + b.Errorf("Write failed: %v", err) + return } } s.Close() @@ -263,7 +273,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { for i := 0; i < N; i++ { _, err := c.Read(buf[:]) if err != nil { - b.Fatalf("Read failed: %v", err) + b.Errorf("Read failed: %v", err) + return } } c.Close() @@ -388,7 +399,7 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) { {"tcp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, } switch runtime.GOOS { - case "darwin", "freebsd", "opensbd", "netbsd": + case "darwin", "freebsd", "openbsd", "netbsd": tests = append(tests, []test{ {"tcp", "[localhost%" + ifi.Name + "]:0", true}, {"tcp6", "[localhost%" + ifi.Name + "]:0", true}, @@ -460,15 +471,25 @@ func TestTCPConcurrentAccept(t *testing.T) { wg.Done() }() } - for i := 0; i < 10*N; i++ { - c, err := Dial("tcp", ln.Addr().String()) + attempts := 10 * N + fails := 0 + d := &Dialer{Timeout: 200 * time.Millisecond} + for i := 0; i < attempts; i++ { + c, err := d.Dial("tcp", ln.Addr().String()) if err != nil { - t.Fatalf("Dial failed: %v", err) + fails++ + } else { + c.Close() } - c.Close() } ln.Close() wg.Wait() + if fails > attempts/9 { // see issues 7400 and 7541 + t.Fatalf("too many Dial failed: %v", fails) + } + if fails > 0 { + t.Logf("# of failed Dials: %v", fails) + } } func TestTCPReadWriteMallocs(t *testing.T) { diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go index cf9c0f890..52019d7b4 100644 --- a/src/pkg/net/tcpsock_plan9.go +++ b/src/pkg/net/tcpsock_plan9.go @@ -32,7 +32,7 @@ func (c *TCPConn) CloseRead() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseRead() + return c.fd.closeRead() } // CloseWrite shuts down the writing side of the TCP connection. @@ -41,20 +41,21 @@ func (c *TCPConn) CloseWrite() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseWrite() + return c.fd.closeWrite() } -// SetLinger sets the behavior of Close() on a connection which still +// SetLinger sets the behavior of Close on a connection which still // has data waiting to be sent or to be acknowledged. // -// If sec < 0 (the default), Close returns immediately and the -// operating system finishes sending the data in the background. +// If sec < 0 (the default), the operating system finishes sending the +// data in the background. // -// If sec == 0, Close returns immediately and the operating system -// discards any unsent or unacknowledged data. +// If sec == 0, the operating system discards any unsent or +// unacknowledged data. // -// If sec > 0, Close blocks for at most sec seconds waiting for data -// to be sent and acknowledged. +// If sec > 0, the data is sent in the background as with sec < 0. On +// some operating systems after sec seconds have elapsed any remaining +// unsent data may be discarded. func (c *TCPConn) SetLinger(sec int) error { return syscall.EPLAN9 } @@ -62,12 +63,18 @@ func (c *TCPConn) SetLinger(sec int) error { // SetKeepAlive sets whether the operating system should send // keepalive messages on the connection. func (c *TCPConn) SetKeepAlive(keepalive bool) error { - return syscall.EPLAN9 + if !c.ok() { + return syscall.EPLAN9 + } + return setKeepAlive(c.fd, keepalive) } // SetKeepAlivePeriod sets period between keep alives. func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error { - return syscall.EPLAN9 + if !c.ok() { + return syscall.EPLAN9 + } + return setKeepAlivePeriod(c.fd, d) } // SetNoDelay controls whether the operating system should delay diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go index 00c692e42..b79b115ca 100644 --- a/src/pkg/net/tcpsock_posix.go +++ b/src/pkg/net/tcpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net @@ -78,7 +78,7 @@ func (c *TCPConn) CloseRead() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseRead() + return c.fd.closeRead() } // CloseWrite shuts down the writing side of the TCP connection. @@ -87,20 +87,21 @@ func (c *TCPConn) CloseWrite() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseWrite() + return c.fd.closeWrite() } -// SetLinger sets the behavior of Close() on a connection which still +// SetLinger sets the behavior of Close on a connection which still // has data waiting to be sent or to be acknowledged. // -// If sec < 0 (the default), Close returns immediately and the -// operating system finishes sending the data in the background. +// If sec < 0 (the default), the operating system finishes sending the +// data in the background. // -// If sec == 0, Close returns immediately and the operating system -// discards any unsent or unacknowledged data. +// If sec == 0, the operating system discards any unsent or +// unacknowledged data. // -// If sec > 0, Close blocks for at most sec seconds waiting for data -// to be sent and acknowledged. +// If sec > 0, the data is sent in the background as with sec < 0. On +// some operating systems after sec seconds have elapsed any remaining +// unsent data may be discarded. func (c *TCPConn) SetLinger(sec int) error { if !c.ok() { return syscall.EINVAL diff --git a/src/pkg/net/tcpsockopt_dragonfly.go b/src/pkg/net/tcpsockopt_dragonfly.go new file mode 100644 index 000000000..d10a77773 --- /dev/null +++ b/src/pkg/net/tcpsockopt_dragonfly.go @@ -0,0 +1,29 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // The kernel expects milliseconds so round to next highest millisecond. + d += (time.Millisecond - time.Nanosecond) + msecs := int(time.Duration(d.Nanoseconds()) / time.Millisecond) + + err := os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs)) + if err != nil { + return err + } + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, msecs)) +} diff --git a/src/pkg/net/tcpsockopt_plan9.go b/src/pkg/net/tcpsockopt_plan9.go new file mode 100644 index 000000000..0e7a6647c --- /dev/null +++ b/src/pkg/net/tcpsockopt_plan9.go @@ -0,0 +1,18 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP socket options for plan9 + +package net + +import ( + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + cmd := "keepalive " + string(int64(d/time.Millisecond)) + _, e := fd.ctl.WriteAt([]byte(cmd), 0) + return e +} diff --git a/src/pkg/net/tcpsockopt_posix.go b/src/pkg/net/tcpsockopt_posix.go index e03476ac6..6484bad4b 100644 --- a/src/pkg/net/tcpsockopt_posix.go +++ b/src/pkg/net/tcpsockopt_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net diff --git a/src/pkg/net/tcpsockopt_solaris.go b/src/pkg/net/tcpsockopt_solaris.go new file mode 100644 index 000000000..eaab6b678 --- /dev/null +++ b/src/pkg/net/tcpsockopt_solaris.go @@ -0,0 +1,27 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP socket options for solaris + +package net + +import ( + "os" + "syscall" + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // The kernel expects seconds so round to next highest second. + d += (time.Second - time.Nanosecond) + secs := int(d.Seconds()) + + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.SO_KEEPALIVE, secs)) +} diff --git a/src/pkg/net/tcpsockopt_unix.go b/src/pkg/net/tcpsockopt_unix.go index 89d9143b5..2693a541d 100644 --- a/src/pkg/net/tcpsockopt_unix.go +++ b/src/pkg/net/tcpsockopt_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build dragonfly freebsd linux netbsd +// +build freebsd linux nacl netbsd package net diff --git a/src/pkg/net/tcpsockopt_windows.go b/src/pkg/net/tcpsockopt_windows.go index 0bf4312f2..8ef140797 100644 --- a/src/pkg/net/tcpsockopt_windows.go +++ b/src/pkg/net/tcpsockopt_windows.go @@ -7,7 +7,10 @@ package net import ( + "os" + "syscall" "time" + "unsafe" ) func setKeepAlivePeriod(fd *netFD, d time.Duration) error { @@ -16,6 +19,16 @@ func setKeepAlivePeriod(fd *netFD, d time.Duration) error { } defer fd.decref() - // We can't actually set this per connection. Act as a noop rather than an error. - return nil + // Windows expects milliseconds so round to next highest millisecond. + d += (time.Millisecond - time.Nanosecond) + millis := uint32(d / time.Millisecond) + ka := syscall.TCPKeepalive{ + OnOff: 1, + Time: millis, + Interval: millis, + } + ret := uint32(0) + size := uint32(unsafe.Sizeof(ka)) + err := syscall.WSAIoctl(fd.sysfd, syscall.SIO_KEEPALIVE_VALS, (*byte)(unsafe.Pointer(&ka)), size, nil, 0, &ret, nil, 0) + return os.NewSyscallError("WSAIoctl", err) } diff --git a/src/pkg/net/testdata/resolv.conf b/src/pkg/net/testdata/resolv.conf new file mode 100644 index 000000000..b5972e09c --- /dev/null +++ b/src/pkg/net/testdata/resolv.conf @@ -0,0 +1,5 @@ +# /etc/resolv.conf + +domain Home +nameserver 192.168.1.1 +options ndots:5 timeout:10 attempts:3 rotate diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go index b0c07413c..eea9207f2 100644 --- a/src/pkg/net/textproto/reader.go +++ b/src/pkg/net/textproto/reader.go @@ -562,19 +562,12 @@ const toLower = 'a' - 'A' // allowed to mutate the provided byte slice before returning the // string. func canonicalMIMEHeaderKey(a []byte) string { - // Look for it in commonHeaders , so that we can avoid an - // allocation by sharing the strings among all users - // of textproto. If we don't find it, a has been canonicalized - // so just return string(a). upper := true - lo := 0 - hi := len(commonHeaders) - for i := 0; i < len(a); i++ { + for i, c := range a { // Canonicalize: first letter upper case // and upper case after each dash. // (Host, User-Agent, If-Modified-Since). // MIME headers are ASCII only, so no Unicode issues. - c := a[i] if c == ' ' { c = '-' } else if upper && 'a' <= c && c <= 'z' { @@ -584,60 +577,61 @@ func canonicalMIMEHeaderKey(a []byte) string { } a[i] = c upper = c == '-' // for next time - - if lo < hi { - for lo < hi && (len(commonHeaders[lo]) <= i || commonHeaders[lo][i] < c) { - lo++ - } - for hi > lo && commonHeaders[hi-1][i] > c { - hi-- - } - } } - if lo < hi && len(commonHeaders[lo]) == len(a) { - return commonHeaders[lo] + // The compiler recognizes m[string(byteSlice)] as a special + // case, so a copy of a's bytes into a new string does not + // happen in this map lookup: + if v := commonHeader[string(a)]; v != "" { + return v } return string(a) } -var commonHeaders = []string{ - "Accept", - "Accept-Charset", - "Accept-Encoding", - "Accept-Language", - "Accept-Ranges", - "Cache-Control", - "Cc", - "Connection", - "Content-Id", - "Content-Language", - "Content-Length", - "Content-Transfer-Encoding", - "Content-Type", - "Cookie", - "Date", - "Dkim-Signature", - "Etag", - "Expires", - "From", - "Host", - "If-Modified-Since", - "If-None-Match", - "In-Reply-To", - "Last-Modified", - "Location", - "Message-Id", - "Mime-Version", - "Pragma", - "Received", - "Return-Path", - "Server", - "Set-Cookie", - "Subject", - "To", - "User-Agent", - "Via", - "X-Forwarded-For", - "X-Imforwards", - "X-Powered-By", +// commonHeader interns common header strings. +var commonHeader = make(map[string]string) + +func init() { + for _, v := range []string{ + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Accept-Ranges", + "Cache-Control", + "Cc", + "Connection", + "Content-Id", + "Content-Language", + "Content-Length", + "Content-Transfer-Encoding", + "Content-Type", + "Cookie", + "Date", + "Dkim-Signature", + "Etag", + "Expires", + "From", + "Host", + "If-Modified-Since", + "If-None-Match", + "In-Reply-To", + "Last-Modified", + "Location", + "Message-Id", + "Mime-Version", + "Pragma", + "Received", + "Return-Path", + "Server", + "Set-Cookie", + "Subject", + "To", + "User-Agent", + "Via", + "X-Forwarded-For", + "X-Imforwards", + "X-Powered-By", + } { + commonHeader[v] = v + } } diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go index cc12912b6..cbc0ed183 100644 --- a/src/pkg/net/textproto/reader_test.go +++ b/src/pkg/net/textproto/reader_test.go @@ -247,24 +247,20 @@ func TestRFC959Lines(t *testing.T) { } func TestCommonHeaders(t *testing.T) { - // need to disable the commonHeaders-based optimization - // during this check, or we'd not be testing anything - oldch := commonHeaders - commonHeaders = []string{} - defer func() { commonHeaders = oldch }() - - last := "" - for _, h := range oldch { - if last > h { - t.Errorf("%v is out of order", h) - } - if last == h { - t.Errorf("%v is duplicated", h) + for h := range commonHeader { + if h != CanonicalMIMEHeaderKey(h) { + t.Errorf("Non-canonical header %q in commonHeader", h) } - if canon := CanonicalMIMEHeaderKey(h); h != canon { - t.Errorf("%v is not canonical", h) + } + b := []byte("content-Length") + want := "Content-Length" + n := testing.AllocsPerRun(200, func() { + if x := canonicalMIMEHeaderKey(b); x != want { + t.Fatalf("canonicalMIMEHeaderKey(%q) = %q; want %q", b, x, want) } - last = h + }) + if n > 0 { + t.Errorf("canonicalMIMEHeaderKey allocs = %v; want 0", n) } } diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 35d427a69..9ef0c4d15 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -120,6 +120,9 @@ func TestReadTimeout(t *testing.T) { t.Fatalf("Read: expected err %v, got %v", errClosing, err) } default: + if err == io.EOF && runtime.GOOS == "nacl" { // close enough; golang.org/issue/8044 + break + } if err != errClosing { t.Fatalf("Read: expected err %v, got %v", errClosing, err) } @@ -348,7 +351,8 @@ func TestReadWriteDeadline(t *testing.T) { go func() { c, err := ln.Accept() if err != nil { - t.Fatalf("Accept: %v", err) + t.Errorf("Accept: %v", err) + return } defer c.Close() lnquit <- true @@ -493,10 +497,7 @@ func testVariousDeadlines(t *testing.T, maxProcs int) { clientc <- copyRes{n, err, d} }() - tooLong := 2 * time.Second - if runtime.GOOS == "windows" { - tooLong = 5 * time.Second - } + tooLong := 5 * time.Second select { case res := <-clientc: if isTimeout(res.err) { @@ -536,7 +537,8 @@ func TestReadDeadlineDataAvailable(t *testing.T) { go func() { c, err := ln.Accept() if err != nil { - t.Fatalf("Accept: %v", err) + t.Errorf("Accept: %v", err) + return } defer c.Close() n, err := c.Write([]byte(msg)) @@ -574,7 +576,8 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) { go func() { c, err := ln.Accept() if err != nil { - t.Fatalf("Accept: %v", err) + t.Errorf("Accept: %v", err) + return } defer c.Close() c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past @@ -610,7 +613,8 @@ func TestAcceptDeadlineConnectionAvailable(t *testing.T) { go func() { c, err := Dial("tcp", ln.Addr().String()) if err != nil { - t.Fatalf("Dial: %v", err) + t.Errorf("Dial: %v", err) + return } defer c.Close() var buf [1]byte @@ -669,7 +673,8 @@ func TestProlongTimeout(t *testing.T) { s, err := ln.Accept() connected <- true if err != nil { - t.Fatalf("ln.Accept: %v", err) + t.Errorf("ln.Accept: %v", err) + return } defer s.Close() s.SetDeadline(time.Now().Add(time.Hour)) @@ -706,7 +711,7 @@ func TestProlongTimeout(t *testing.T) { func TestDeadlineRace(t *testing.T) { switch runtime.GOOS { - case "plan9": + case "nacl", "plan9": t.Skipf("skipping test on %q", runtime.GOOS) } diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go index 6f4d2152c..e1778779c 100644 --- a/src/pkg/net/udp_test.go +++ b/src/pkg/net/udp_test.go @@ -201,6 +201,10 @@ func TestIPv6LinkLocalUnicastUDP(t *testing.T) { {"udp", "[" + laddr + "%" + ifi.Name + "]:0", false}, {"udp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, } + // The first udp test fails on DragonFly - see issue 7473. + if runtime.GOOS == "dragonfly" { + tests = tests[1:] + } switch runtime.GOOS { case "darwin", "dragonfly", "freebsd", "openbsd", "netbsd": tests = append(tests, []test{ diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 0dd0dbd71..4c99ae4af 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -4,10 +4,6 @@ package net -import "errors" - -var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP") - // UDPAddr represents the address of a UDP end point. type UDPAddr struct { IP IP diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go index 73621706d..510ac5e4a 100644 --- a/src/pkg/net/udpsock_plan9.go +++ b/src/pkg/net/udpsock_plan9.go @@ -190,7 +190,8 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { if err != nil { return nil, err } - return newUDPConn(l.netFD()), nil + fd, err := l.netFD() + return newUDPConn(fd), err } // ListenMulticastUDP listens for incoming multicast UDP packets diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go index 142da8186..5dfba94e9 100644 --- a/src/pkg/net/udpsock_posix.go +++ b/src/pkg/net/udpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net @@ -64,7 +64,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, sa, err := c.fd.ReadFrom(b) + n, sa, err := c.fd.readFrom(b) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} @@ -93,7 +93,7 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, return 0, 0, 0, nil, syscall.EINVAL } var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob) + n, oobn, flags, sa, err = c.fd.readMsg(b, oob) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} @@ -124,7 +124,7 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if err != nil { return 0, &OpError{"write", c.fd.net, addr, err} } - return c.fd.WriteTo(b, sa) + return c.fd.writeTo(b, sa) } // WriteTo implements the PacketConn WriteTo method. @@ -156,7 +156,7 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er if err != nil { return 0, 0, &OpError{"write", c.fd.net, addr, err} } - return c.fd.WriteMsg(b, oob, sa) + return c.fd.writeMsg(b, oob, sa) } // DialUDP connects to the remote address raddr on the network net, diff --git a/src/pkg/net/unicast_posix_test.go b/src/pkg/net/unicast_posix_test.go index 5deb8f47c..452ac9254 100644 --- a/src/pkg/net/unicast_posix_test.go +++ b/src/pkg/net/unicast_posix_test.go @@ -166,9 +166,12 @@ var dualStackListenerTests = []struct { } // TestDualStackTCPListener tests both single and double listen -// to a test listener with various address families, differnet +// to a test listener with various address families, different // listening address and same port. func TestDualStackTCPListener(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode, see issue 5001") + } switch runtime.GOOS { case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) @@ -178,7 +181,7 @@ func TestDualStackTCPListener(t *testing.T) { } for _, tt := range dualStackListenerTests { - if tt.wildcard && (testing.Short() || !*testExternal) { + if tt.wildcard && !*testExternal { continue } switch runtime.GOOS { diff --git a/src/pkg/net/unix_test.go b/src/pkg/net/unix_test.go index 91df3ff88..05643ddf9 100644 --- a/src/pkg/net/unix_test.go +++ b/src/pkg/net/unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !plan9,!windows +// +build !nacl,!plan9,!windows package net @@ -151,6 +151,73 @@ func TestUnixAutobindClose(t *testing.T) { ln.Close() } +func TestUnixgramWrite(t *testing.T) { + addr := testUnixAddr() + laddr, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := ListenPacket("unixgram", addr) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer os.Remove(addr) + defer c.Close() + + testUnixgramWriteConn(t, laddr) + testUnixgramWritePacketConn(t, laddr) +} + +func testUnixgramWriteConn(t *testing.T, raddr *UnixAddr) { + c, err := Dial("unixgram", raddr.String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + + if _, err := c.(*UnixConn).WriteToUnix([]byte("Connection-oriented mode socket"), raddr); err == nil { + t.Fatal("WriteToUnix should fail") + } else if err.(*OpError).Err != ErrWriteToConnected { + t.Fatalf("WriteToUnix should fail as ErrWriteToConnected: %v", err) + } + if _, err = c.(*UnixConn).WriteTo([]byte("Connection-oriented mode socket"), raddr); err == nil { + t.Fatal("WriteTo should fail") + } else if err.(*OpError).Err != ErrWriteToConnected { + t.Fatalf("WriteTo should fail as ErrWriteToConnected: %v", err) + } + if _, _, err = c.(*UnixConn).WriteMsgUnix([]byte("Connection-oriented mode socket"), nil, raddr); err == nil { + t.Fatal("WriteTo should fail") + } else if err.(*OpError).Err != ErrWriteToConnected { + t.Fatalf("WriteMsgUnix should fail as ErrWriteToConnected: %v", err) + } + if _, err := c.Write([]byte("Connection-oriented mode socket")); err != nil { + t.Fatalf("Write failed: %v", err) + } +} + +func testUnixgramWritePacketConn(t *testing.T, raddr *UnixAddr) { + addr := testUnixAddr() + c, err := ListenPacket("unixgram", addr) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer os.Remove(addr) + defer c.Close() + + if _, err := c.(*UnixConn).WriteToUnix([]byte("Connectionless mode socket"), raddr); err != nil { + t.Fatalf("WriteToUnix failed: %v", err) + } + if _, err := c.WriteTo([]byte("Connectionless mode socket"), raddr); err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + if _, _, err := c.(*UnixConn).WriteMsgUnix([]byte("Connectionless mode socket"), nil, raddr); err != nil { + t.Fatalf("WriteMsgUnix failed: %v", err) + } + if _, err := c.(*UnixConn).Write([]byte("Connectionless mode socket")); err == nil { + t.Fatal("Write should fail") + } +} + func TestUnixConnLocalAndRemoteNames(t *testing.T) { for _, laddr := range []string{"", testUnixAddr()} { laddr := laddr diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go index b82f3cee0..2610779bf 100644 --- a/src/pkg/net/unixsock_posix.go +++ b/src/pkg/net/unixsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows package net @@ -124,7 +124,7 @@ func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, sa, err := c.fd.ReadFrom(b) + n, sa, err := c.fd.readFrom(b) switch sa := sa.(type) { case *syscall.SockaddrUnix: if sa.Name != "" { @@ -151,7 +151,7 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd if !c.ok() { return 0, 0, 0, nil, syscall.EINVAL } - n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) + n, oobn, flags, sa, err := c.fd.readMsg(b, oob) switch sa := sa.(type) { case *syscall.SockaddrUnix: if sa.Name != "" { @@ -171,6 +171,9 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) { if !c.ok() { return 0, syscall.EINVAL } + if c.fd.isConnected { + return 0, &OpError{Op: "write", Net: c.fd.net, Addr: addr, Err: ErrWriteToConnected} + } if addr == nil { return 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} } @@ -178,7 +181,7 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) { return 0, syscall.EAFNOSUPPORT } sa := &syscall.SockaddrUnix{Name: addr.Name} - return c.fd.WriteTo(b, sa) + return c.fd.writeTo(b, sa) } // WriteTo implements the PacketConn WriteTo method. @@ -200,14 +203,17 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err if !c.ok() { return 0, 0, syscall.EINVAL } + if c.fd.sotype == syscall.SOCK_DGRAM && c.fd.isConnected { + return 0, 0, &OpError{Op: "write", Net: c.fd.net, Addr: addr, Err: ErrWriteToConnected} + } if addr != nil { if addr.Net != sotypeToNet(c.fd.sotype) { return 0, 0, syscall.EAFNOSUPPORT } sa := &syscall.SockaddrUnix{Name: addr.Name} - return c.fd.WriteMsg(b, oob, sa) + return c.fd.writeMsg(b, oob, sa) } - return c.fd.WriteMsg(b, oob, nil) + return c.fd.writeMsg(b, oob, nil) } // CloseRead shuts down the reading side of the Unix domain connection. @@ -216,7 +222,7 @@ func (c *UnixConn) CloseRead() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseRead() + return c.fd.closeRead() } // CloseWrite shuts down the writing side of the Unix domain connection. @@ -225,7 +231,7 @@ func (c *UnixConn) CloseWrite() error { if !c.ok() { return syscall.EINVAL } - return c.fd.CloseWrite() + return c.fd.closeWrite() } // DialUnix connects to the remote address raddr on the network net, @@ -280,7 +286,11 @@ func (l *UnixListener) AcceptUnix() (*UnixConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL } - fd, err := l.fd.accept(sockaddrToUnix) + toAddr := sockaddrToUnix + if l.fd.sotype == syscall.SOCK_SEQPACKET { + toAddr = sockaddrToUnixpacket + } + fd, err := l.fd.accept(toAddr) if err != nil { return nil, err } diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go index 3b3787202..75f650a27 100644 --- a/src/pkg/net/url/url.go +++ b/src/pkg/net/url/url.go @@ -502,7 +502,7 @@ func (v Values) Set(key, value string) { v[key] = []string{value} } -// Add adds the key to value. It appends to any existing +// Add adds the value to key. It appends to any existing // values associated with key. func (v Values) Add(key, value string) { v[key] = append(v[key], value) diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go index 7578eb15b..cad758f23 100644 --- a/src/pkg/net/url/url_test.go +++ b/src/pkg/net/url/url_test.go @@ -251,6 +251,17 @@ var urltests = []URLTest{ }, "file:///home/adg/rabbits", }, + // "Windows" paths are no exception to the rule. + // See golang.org/issue/6027, especially comment #9. + { + "file:///C:/FooBar/Baz.txt", + &URL{ + Scheme: "file", + Host: "", + Path: "/C:/FooBar/Baz.txt", + }, + "file:///C:/FooBar/Baz.txt", + }, // case-insensitive scheme { "MaIlTo:webmaster@golang.org", diff --git a/src/pkg/net/z_last_test.go b/src/pkg/net/z_last_test.go new file mode 100644 index 000000000..4f6a54a56 --- /dev/null +++ b/src/pkg/net/z_last_test.go @@ -0,0 +1,37 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "fmt" + "testing" +) + +var testDNSFlood = flag.Bool("dnsflood", false, "whether to test dns query flooding") + +func TestDNSThreadLimit(t *testing.T) { + if !*testDNSFlood { + t.Skip("test disabled; use -dnsflood to enable") + } + + const N = 10000 + c := make(chan int, N) + for i := 0; i < N; i++ { + go func(i int) { + LookupIP(fmt.Sprintf("%d.net-test.golang.org", i)) + c <- 1 + }(i) + } + // Don't bother waiting for the stragglers; stop at 0.9 N. + for i := 0; i < N*9/10; i++ { + if i%100 == 0 { + //println("TestDNSThreadLimit:", i) + } + <-c + } + + // If we're still here, it worked. +} |