diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-01-17 12:40:45 +0100 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-01-17 12:40:45 +0100 |
commit | 3e45412327a2654a77944249962b3652e6142299 (patch) | |
tree | bc3bf69452afa055423cbe0c5cfa8ca357df6ccf /src/pkg/websocket | |
parent | c533680039762cacbc37db8dc7eed074c3e497be (diff) | |
download | golang-upstream/2011.01.12.tar.gz |
Imported Upstream version 2011.01.12upstream/2011.01.12
Diffstat (limited to 'src/pkg/websocket')
-rw-r--r-- | src/pkg/websocket/Makefile | 2 | ||||
-rw-r--r-- | src/pkg/websocket/client.go | 113 | ||||
-rw-r--r-- | src/pkg/websocket/server.go | 75 | ||||
-rw-r--r-- | src/pkg/websocket/websocket.go | 92 | ||||
-rw-r--r-- | src/pkg/websocket/websocket_test.go | 133 |
5 files changed, 292 insertions, 123 deletions
diff --git a/src/pkg/websocket/Makefile b/src/pkg/websocket/Makefile index 145d8f429..6d3c9cbd1 100644 --- a/src/pkg/websocket/Makefile +++ b/src/pkg/websocket/Makefile @@ -1,4 +1,4 @@ -include ../../Make.$(GOARCH) +include ../../Make.inc TARG=websocket GOFILES=\ diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go index 2966450a6..091345944 100644 --- a/src/pkg/websocket/client.go +++ b/src/pkg/websocket/client.go @@ -5,11 +5,10 @@ package websocket import ( - "encoding/binary" "bufio" "bytes" "container/vector" - "crypto/md5" + "crypto/tls" "fmt" "http" "io" @@ -24,6 +23,7 @@ type ProtocolError struct { } var ( + ErrBadScheme = os.ErrorString("bad scheme") ErrBadStatus = &ProtocolError{"bad status"} ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} @@ -33,6 +33,17 @@ var ( secKeyRandomChars [0x30 - 0x21 + 0x7F - 0x3A]byte ) +type DialError struct { + URL string + Protocol string + Origin string + Error os.Error +} + +func (e *DialError) String() string { + return "websocket.Dial " + e.URL + ": " + e.Error.String() +} + func init() { i := 0 for ch := byte(0x21); ch < 0x30; ch++ { @@ -61,8 +72,9 @@ func newClient(resourceName, host, origin, location, protocol string, rwc io.Rea } /* - Dial opens a new client connection to a Web Socket. - A trivial example client is: +Dial opens a new client connection to a Web Socket. + +A trivial example client: package main @@ -87,21 +99,40 @@ func newClient(resourceName, host, origin, location, protocol string, rwc io.Rea } */ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) { + var client net.Conn + parsedUrl, err := http.ParseURL(url) if err != nil { - return + goto Error + } + + switch parsedUrl.Scheme { + case "ws": + client, err = net.Dial("tcp", "", parsedUrl.Host) + + case "wss": + client, err = tls.Dial("tcp", "", parsedUrl.Host, nil) + + default: + err = ErrBadScheme } - client, err := net.Dial("tcp", "", parsedUrl.Host) if err != nil { - return + goto Error } - return newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake) + + ws, err = newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake) + if err != nil { + goto Error + } + return + +Error: + return nil, &DialError{url, protocol, origin, err} } /* - Generates handshake key as described in 4.1 Opening handshake - step 16 to 22. - cf. http://www.whatwg.org/specs/web-socket-protocol/ +Generates handshake key as described in 4.1 Opening handshake step 16 to 22. +cf. http://www.whatwg.org/specs/web-socket-protocol/ */ func generateKeyNumber() (key string, number uint32) { // 16. Let /spaces_n/ be a random integer from 1 to 12 inclusive. @@ -123,14 +154,7 @@ func generateKeyNumber() (key string, number uint32) { // to U+0039 DIGIT NINE (9). key = fmt.Sprintf("%d", product) - // 21. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random - // posisions. - for i := 0; i < spaces; i++ { - pos := rand.Intn(len(key)-1) + 1 - key = key[0:pos] + " " + key[pos:] - } - - // 22. Insert between one and twelve random characters from the ranges + // 21. Insert between one and twelve random characters from the ranges // U+0021 to U+002F and U+003A to U+007E into /key_n/ at random // positions. n := rand.Intn(12) + 1 @@ -139,13 +163,20 @@ func generateKeyNumber() (key string, number uint32) { ch := secKeyRandomChars[rand.Intn(len(secKeyRandomChars))] key = key[0:pos] + string(ch) + key[pos:] } + + // 22. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random + // positions other than the start or end of the string. + for i := 0; i < spaces; i++ { + pos := rand.Intn(len(key)-1) + 1 + key = key[0:pos] + " " + key[pos:] + } + return } /* - Generates handshake key_3 as described in 4.1 Opening handshake - step 26. - cf. http://www.whatwg.org/specs/web-socket-protocol/ +Generates handshake key_3 as described in 4.1 Opening handshake step 26. +cf. http://www.whatwg.org/specs/web-socket-protocol/ */ func generateKey3() (key []byte) { // 26. Let /key3/ be a string consisting of eight random bytes (or @@ -158,35 +189,9 @@ func generateKey3() (key []byte) { } /* - Gets expected from challenge as described in 4.1 Opening handshake - Step 42 to 43. - cf. http://www.whatwg.org/specs/web-socket-protocol/ -*/ -func getExpectedForChallenge(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) { - // 41. Let /challenge/ be the concatenation of /number_1/, expressed - // a big-endian 32 bit integer, /number_2/, expressed in a big- - // endian 32 bit integer, and the eight bytes of /key_3/ in the - // order they were sent to the wire. - challenge := make([]byte, 16) - challengeBuf := bytes.NewBuffer(challenge) - binary.Write(challengeBuf, binary.BigEndian, number1) - binary.Write(challengeBuf, binary.BigEndian, number2) - copy(challenge[8:], key3) - - // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big- - // endian 128 bit string. - h := md5.New() - if _, err = h.Write(challenge); err != nil { - return - } - expected = h.Sum() - return -} - -/* - Web Socket protocol handshake based on - http://www.whatwg.org/specs/web-socket-protocol/ - (draft of http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol) +Web Socket protocol handshake based on +http://www.whatwg.org/specs/web-socket-protocol/ +(draft of http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol) */ func handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { // 4.1. Opening handshake. @@ -258,7 +263,7 @@ func handshake(resourceName, host, origin, location, protocol string, br *bufio. } // Step 42-43. get expected data from challange data. - expected, err := getExpectedForChallenge(number1, number2, key3) + expected, err := getChallengeResponse(number1, number2, key3) if err != nil { return err } @@ -278,8 +283,8 @@ func handshake(resourceName, host, origin, location, protocol string, br *bufio. } /* - Handhake described in (soon obsolete) - draft-hixie-thewebsocket-protocol-75. +Handhake described in (soon obsolete) +draft-hixie-thewebsocket-protocol-75. */ func draft75handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n") diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 00b537e27..dd797f24e 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -5,17 +5,15 @@ package websocket import ( - "bytes" - "crypto/md5" - "encoding/binary" "http" "io" "strings" ) /* - Handler is an interface to a WebSocket. - A trivial example server is: +Handler is an interface to a WebSocket. + +A trivial example server: package main @@ -41,8 +39,8 @@ import ( type Handler func(*Conn) /* - Gets key number from Sec-WebSocket-Key<n>: field as described - in 5.2 Sending the server's opening handshake, 4. +Gets key number from Sec-WebSocket-Key<n>: field as described +in 5.2 Sending the server's opening handshake, 4. */ func getKeyNumber(s string) (r uint32) { // 4. Let /key-number_n/ be the digits (characters in the range @@ -59,8 +57,8 @@ func getKeyNumber(s string) (r uint32) { } // ServeHTTP implements the http.Handler interface for a Web Socket -func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { - rwc, buf, err := c.Hijack() +func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + rwc, buf, err := w.Hijack() if err != nil { panic("Hijack failed: " + err.String()) return @@ -99,7 +97,12 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { return } - location := "ws://" + req.Host + req.URL.RawPath + var location string + if w.UsingTLS() { + location = "wss://" + req.Host + req.URL.RawPath + } else { + location = "ws://" + req.Host + req.URL.RawPath + } // Step 4. get key number in Sec-WebSocket-Key<n> fields. keyNumber1 := getKeyNumber(key1) @@ -122,25 +125,11 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { part2 := keyNumber2 / space2 // Step 8. let challenge to be concatination of part1, part2 and key3. - challenge := make([]byte, 16) - challengeBuf := bytes.NewBuffer(challenge) - err = binary.Write(challengeBuf, binary.BigEndian, part1) - if err != nil { - return - } - err = binary.Write(challengeBuf, binary.BigEndian, part2) - if err != nil { - return - } - if n := copy(challenge[8:], key3); n != 8 { - return - } // Step 9. get MD5 fingerprint of challenge. - h := md5.New() - if _, err = h.Write(challenge); err != nil { + response, err := getChallengeResponse(part1, part2, key3) + if err != nil { return } - response := h.Sum() // Step 10. send response status line. buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n") @@ -149,7 +138,7 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n") buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n") - protocol, found := req.Header["Sec-WebSocket-Protocol"] + protocol, found := req.Header["Sec-Websocket-Protocol"] if found { buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n") } @@ -166,42 +155,48 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { /* - Draft75Handler is an interface to a WebSocket based on - (soon obsolete) draft-hixie-thewebsocketprotocol-75. +Draft75Handler is an interface to a WebSocket based on the +(soon obsolete) draft-hixie-thewebsocketprotocol-75. */ type Draft75Handler func(*Conn) // ServeHTTP implements the http.Handler interface for a Web Socket. -func (f Draft75Handler) ServeHTTP(c *http.Conn, req *http.Request) { +func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "GET" || req.Proto != "HTTP/1.1" { - c.WriteHeader(http.StatusBadRequest) - io.WriteString(c, "Unexpected request") + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "Unexpected request") return } if req.Header["Upgrade"] != "WebSocket" { - c.WriteHeader(http.StatusBadRequest) - io.WriteString(c, "missing Upgrade: WebSocket header") + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "missing Upgrade: WebSocket header") return } if req.Header["Connection"] != "Upgrade" { - c.WriteHeader(http.StatusBadRequest) - io.WriteString(c, "missing Connection: Upgrade header") + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "missing Connection: Upgrade header") return } origin, found := req.Header["Origin"] if !found { - c.WriteHeader(http.StatusBadRequest) - io.WriteString(c, "missing Origin header") + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "missing Origin header") return } - rwc, buf, err := c.Hijack() + rwc, buf, err := w.Hijack() if err != nil { panic("Hijack failed: " + err.String()) return } defer rwc.Close() - location := "ws://" + req.Host + req.URL.RawPath + + var location string + if w.UsingTLS() { + location = "wss://" + req.Host + req.URL.RawPath + } else { + location = "ws://" + req.Host + req.URL.RawPath + } // TODO(ukai): verify origin,location,protocol. diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go index bcb42f508..d5996abe1 100644 --- a/src/pkg/websocket/websocket.go +++ b/src/pkg/websocket/websocket.go @@ -11,6 +11,8 @@ package websocket import ( "bufio" + "crypto/md5" + "encoding/binary" "io" "net" "os" @@ -25,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" } // String returns the network address for a Web Socket. func (addr WebSocketAddr) String() string { return string(addr) } +const ( + stateFrameByte = iota + stateFrameLength + stateFrameData + stateFrameTextData +) + // Conn is a channel to communicate to a Web Socket. // It implements the net.Conn interface. type Conn struct { @@ -37,6 +46,10 @@ type Conn struct { buf *bufio.ReadWriter rwc io.ReadWriteCloser + + // It holds text data in previous Read() that failed with small buffer. + data []byte + reading bool } // newConn creates a new Web Socket. @@ -46,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re bw := bufio.NewWriter(rwc) buf = bufio.NewReadWriter(br, bw) } - ws := &Conn{origin, location, protocol, buf, rwc} + ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc} return ws } // Read implements the io.Reader interface for a Conn. func (ws *Conn) Read(msg []byte) (n int, err os.Error) { - for { - frameByte, err := ws.buf.ReadByte() +Frame: + for !ws.reading && len(ws.data) == 0 { + // Beginning of frame, possibly. + b, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } - if (frameByte & 0x80) == 0x80 { + if b&0x80 == 0x80 { + // Skip length frame. length := 0 for { c, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } length = length*128 + int(c&0x7f) - if (c & 0x80) == 0 { + if c&0x80 == 0 { break } } for length > 0 { _, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } - length-- } - } else { + continue Frame + } + // In text mode + if b != 0 { + // Skip this frame for { c, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } if c == '\xff' { - return n, err - } - if frameByte == 0 { - if n+1 <= cap(msg) { - msg = msg[0 : n+1] - } - msg[n] = c - n++ - } - if n >= cap(msg) { - return n, os.E2BIG + break } } + continue Frame } + ws.reading = true } - - panic("unreachable") + if len(ws.data) == 0 { + ws.data, err = ws.buf.ReadSlice('\xff') + if err == nil { + ws.reading = false + ws.data = ws.data[:len(ws.data)-1] // trim \xff + } + } + n = copy(msg, ws.data) + ws.data = ws.data[n:] + return n, err } // Write implements the io.Writer interface for a Conn. @@ -136,7 +155,7 @@ func (ws *Conn) SetReadTimeout(nsec int64) os.Error { return os.EINVAL } -// SeWritetTimeout sets the connection's network write timeout in nanoseconds. +// SetWritetTimeout sets the connection's network write timeout in nanoseconds. func (ws *Conn) SetWriteTimeout(nsec int64) os.Error { if conn, ok := ws.rwc.(net.Conn); ok { return conn.SetWriteTimeout(nsec) @@ -144,4 +163,27 @@ func (ws *Conn) SetWriteTimeout(nsec int64) os.Error { return os.EINVAL } +// getChallengeResponse computes the expected response from the +// challenge as described in section 5.1 Opening Handshake steps 42 to +// 43 of http://www.whatwg.org/specs/web-socket-protocol/ +func getChallengeResponse(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) { + // 41. Let /challenge/ be the concatenation of /number_1/, expressed + // a big-endian 32 bit integer, /number_2/, expressed in a big- + // endian 32 bit integer, and the eight bytes of /key_3/ in the + // order they were sent to the wire. + challenge := make([]byte, 16) + binary.BigEndian.PutUint32(challenge[0:], number1) + binary.BigEndian.PutUint32(challenge[4:], number2) + copy(challenge[8:], key3) + + // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big- + // endian 128 bit string. + h := md5.New() + if _, err = h.Write(challenge); err != nil { + return + } + expected = h.Sum() + return +} + var _ net.Conn = (*Conn)(nil) // compile-time check that *Conn implements net.Conn. diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index df7e9f4da..cc4b9dc18 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -5,17 +5,19 @@ package websocket import ( + "bufio" "bytes" "fmt" "http" "io" "log" "net" - "once" + "sync" "testing" ) var serverAddr string +var once sync.Once func echoServer(ws *Conn) { io.Copy(ws, ws) } @@ -25,12 +27,31 @@ func startServer() { log.Exitf("net.Listen tcp :0 %v", e) } serverAddr = l.Addr().String() - log.Stderr("Test WebSocket server listening on ", serverAddr) + log.Print("Test WebSocket server listening on ", serverAddr) http.Handle("/echo", Handler(echoServer)) http.Handle("/echoDraft75", Draft75Handler(echoServer)) go http.Serve(l, nil) } +// Test the getChallengeResponse function with values from section +// 5.1 of the specification steps 18, 26, and 43 from +// http://www.whatwg.org/specs/web-socket-protocol/ +func TestChallenge(t *testing.T) { + var part1 uint32 = 777007543 + var part2 uint32 = 114997259 + key3 := []byte{0x47, 0x30, 0x22, 0x2D, 0x5A, 0x3F, 0x47, 0x58} + expected := []byte("0st3Rl&q-2ZU^weu") + + response, err := getChallengeResponse(part1, part2, key3) + if err != nil { + t.Errorf("getChallengeResponse: returned error %v", err) + return + } + if !bytes.Equal(expected, response) { + t.Errorf("getChallengeResponse: expected %q got %q", expected, response) + } +} + func TestEcho(t *testing.T) { once.Do(startServer) @@ -110,6 +131,23 @@ func TestWithQuery(t *testing.T) { ws.Close() } +func TestWithProtocol(t *testing.T) { + once.Do(startServer) + + client, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing", err) + } + + ws, err := newClient("/echo", "localhost", "http://localhost", + "ws://localhost/echo", "test", client, handshake) + if err != nil { + t.Errorf("WebSocket handshake: %v", err) + return + } + ws.Close() +} + func TestHTTP(t *testing.T) { once.Do(startServer) @@ -117,7 +155,7 @@ func TestHTTP(t *testing.T) { // specification, the server should abort the WebSocket connection. _, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) if err == nil { - t.Errorf("Get: unexpected success") + t.Error("Get: unexpected success") return } urlerr, ok := err.(*http.URLError) @@ -143,3 +181,92 @@ func TestHTTPDraft75(t *testing.T) { t.Errorf("Get: got status %d", r.StatusCode) } } + +func TestTrailingSpaces(t *testing.T) { + // http://code.google.com/p/go/issues/detail?id=955 + // The last runs of this create keys with trailing spaces that should not be + // generated by the client. + once.Do(startServer) + for i := 0; i < 30; i++ { + // body + _, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", + "http://localhost/") + if err != nil { + panic("Dial failed: " + err.String()) + } + } +} + +func TestSmallBuffer(t *testing.T) { + // http://code.google.com/p/go/issues/detail?id=1145 + // Read should be able to handle reading a fragment of a frame. + once.Do(startServer) + + // websocket.Dial() + client, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing", err) + } + ws, err := newClient("/echo", "localhost", "http://localhost", + "ws://localhost/echo", "", client, handshake) + if err != nil { + t.Errorf("WebSocket handshake error: %v", err) + return + } + + msg := []byte("hello, world\n") + if _, err := ws.Write(msg); err != nil { + t.Errorf("Write: %v", err) + } + var small_msg = make([]byte, 8) + n, err := ws.Read(small_msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(msg[:len(small_msg)], small_msg) { + t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) + } + var second_msg = make([]byte, len(msg)) + n, err = ws.Read(second_msg) + if err != nil { + t.Errorf("Read: %v", err) + } + second_msg = second_msg[0:n] + if !bytes.Equal(msg[len(small_msg):], second_msg) { + t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) + } + ws.Close() + +} + +func testSkipLengthFrame(t *testing.T) { + b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:8], msg[0:n]) { + t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) + } +} + +func testSkipNoUTF8Frame(t *testing.T) { + b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:8], msg[0:n]) { + t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) + } +} |