diff options
author | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
---|---|---|
committer | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
commit | f154da9e12608589e8d5f0508f908a0c3e88a1bb (patch) | |
tree | f8255d51e10c6f1e0ed69702200b966c9556a431 /src/net/http/httptest | |
parent | 8d8329ed5dfb9622c82a9fbec6fd99a580f9c9f6 (diff) | |
download | golang-upstream/1.4.tar.gz |
Imported Upstream version 1.4upstream/1.4
Diffstat (limited to 'src/net/http/httptest')
-rw-r--r-- | src/net/http/httptest/example_test.go | 50 | ||||
-rw-r--r-- | src/net/http/httptest/recorder.go | 72 | ||||
-rw-r--r-- | src/net/http/httptest/recorder_test.go | 90 | ||||
-rw-r--r-- | src/net/http/httptest/server.go | 228 | ||||
-rw-r--r-- | src/net/http/httptest/server_test.go | 29 |
5 files changed, 469 insertions, 0 deletions
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go new file mode 100644 index 000000000..42a0ec953 --- /dev/null +++ b/src/net/http/httptest/example_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleResponseRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "something failed", http.StatusInternalServerError) + } + + req, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + log.Fatal(err) + } + + w := httptest.NewRecorder() + handler(w, req) + + fmt.Printf("%d - %s", w.Code, w.Body.String()) + // Output: 500 - something failed +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go new file mode 100644 index 000000000..5451f5423 --- /dev/null +++ b/src/net/http/httptest/recorder.go @@ -0,0 +1,72 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bytes" + "net/http" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + Code int // the HTTP response code from WriteHeader + HeaderMap http.Header // the HTTP response headers + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + Flushed bool + + wroteHeader bool +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// Header returns the response headers. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// Write always succeeds and writes to rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteHeader sets rw.Code. +func (rw *ResponseRecorder) WriteHeader(code int) { + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true +} + +// Flush sets rw.Flushed to true. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go new file mode 100644 index 000000000..2b563260c --- /dev/null +++ b/src/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go new file mode 100644 index 000000000..789e7bf41 --- /dev/null +++ b/src/net/http/httptest/server.go @@ -0,0 +1,228 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "crypto/tls" + "flag" + "fmt" + "net" + "net/http" + "os" + "sync" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup +} + +// historyListener keeps track of all connections that it's ever +// accepted. +type historyListener struct { + net.Listener + sync.Mutex // protects history + history []net.Conn +} + +func (hs *historyListener) Accept() (c net.Conn, err error) { + c, err = hs.Listener.Accept() + if err == nil { + hs.Lock() + hs.history = append(hs.history, c) + hs.Unlock() + } + return +} + +func newLocalListener() net.Listener { + if *serve != "" { + l, err := net.Listen("tcp", *serve) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) + } + return l + } + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// to start the broken server so you can interact with it manually. +var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks") + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + s.Listener = &historyListener{Listener: s.Listener} + s.URL = "http://" + s.Listener.Addr().String() + s.wrapHandler() + go s.Config.Serve(s.Listener) + if *serve != "" { + fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) + select {} + } +} + +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + tlsListener := tls.NewListener(s.Listener, s.TLS) + + s.Listener = &historyListener{Listener: tlsListener} + s.URL = "https://" + s.Listener.Addr().String() + s.wrapHandler() + go s.Config.Serve(s.Listener) +} + +func (s *Server) wrapHandler() { + h := s.Config.Handler + if h == nil { + h = http.DefaultServeMux + } + s.Config.Handler = &waitGroupHandler{ + s: s, + h: h, + } +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() + return ts +} + +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. +func (s *Server) Close() { + s.Listener.Close() + s.wg.Wait() + s.CloseClientConnections() + if t, ok := http.DefaultTransport.(*http.Transport); ok { + t.CloseIdleConnections() + } +} + +// CloseClientConnections closes any currently open HTTP connections +// to the test Server. +func (s *Server) CloseClientConnections() { + hl, ok := s.Listener.(*historyListener) + if !ok { + return + } + hl.Lock() + for _, conn := range hl.history { + conn.Close() + } + hl.Unlock() +} + +// waitGroupHandler wraps a handler, incrementing and decrementing a +// sync.WaitGroup on each request, to enable Server.Close to block +// until outstanding requests are finished. +type waitGroupHandler struct { + s *Server + h http.Handler // non-nil +} + +func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.s.wg.Add(1) + defer h.s.wg.Done() // a defer, in case ServeHTTP below panics + h.h.ServeHTTP(w, r) +} + +// localhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end +// of ASN.1 time). +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go new file mode 100644 index 000000000..500a9f0b8 --- /dev/null +++ b/src/net/http/httptest/server_test.go @@ -0,0 +1,29 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "io/ioutil" + "net/http" + "testing" +) + +func TestServer(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} |