summaryrefslogtreecommitdiff
path: root/src/pkg/websocket
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/websocket')
-rw-r--r--src/pkg/websocket/client.go20
-rw-r--r--src/pkg/websocket/server.go36
-rw-r--r--src/pkg/websocket/websocket_test.go11
3 files changed, 32 insertions, 35 deletions
diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go
index 091345944..d8a7aa0a2 100644
--- a/src/pkg/websocket/client.go
+++ b/src/pkg/websocket/client.go
@@ -245,20 +245,20 @@ func handshake(resourceName, host, origin, location, protocol string, br *bufio.
}
// Step 41. check websocket headers.
- if resp.Header["Upgrade"] != "WebSocket" ||
- strings.ToLower(resp.Header["Connection"]) != "upgrade" {
+ if resp.Header.Get("Upgrade") != "WebSocket" ||
+ strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
- if resp.Header["Sec-Websocket-Origin"] != origin {
+ if resp.Header.Get("Sec-Websocket-Origin") != origin {
return ErrBadWebSocketOrigin
}
- if resp.Header["Sec-Websocket-Location"] != location {
+ if resp.Header.Get("Sec-Websocket-Location") != location {
return ErrBadWebSocketLocation
}
- if protocol != "" && resp.Header["Sec-Websocket-Protocol"] != protocol {
+ if protocol != "" && resp.Header.Get("Sec-Websocket-Protocol") != protocol {
return ErrBadWebSocketProtocol
}
@@ -304,17 +304,17 @@ func draft75handshake(resourceName, host, origin, location, protocol string, br
if resp.Status != "101 Web Socket Protocol Handshake" {
return ErrBadStatus
}
- if resp.Header["Upgrade"] != "WebSocket" ||
- resp.Header["Connection"] != "Upgrade" {
+ if resp.Header.Get("Upgrade") != "WebSocket" ||
+ resp.Header.Get("Connection") != "Upgrade" {
return ErrBadUpgrade
}
- if resp.Header["Websocket-Origin"] != origin {
+ if resp.Header.Get("Websocket-Origin") != origin {
return ErrBadWebSocketOrigin
}
- if resp.Header["Websocket-Location"] != location {
+ if resp.Header.Get("Websocket-Location") != location {
return ErrBadWebSocketLocation
}
- if protocol != "" && resp.Header["Websocket-Protocol"] != protocol {
+ if protocol != "" && resp.Header.Get("Websocket-Protocol") != protocol {
return ErrBadWebSocketProtocol
}
return
diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go
index dd797f24e..37149f044 100644
--- a/src/pkg/websocket/server.go
+++ b/src/pkg/websocket/server.go
@@ -58,7 +58,7 @@ func getKeyNumber(s string) (r uint32) {
// ServeHTTP implements the http.Handler interface for a Web Socket
func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- rwc, buf, err := w.Hijack()
+ rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.String())
return
@@ -73,23 +73,23 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
// HTTP version can be safely ignored.
- if strings.ToLower(req.Header["Upgrade"]) != "websocket" ||
- strings.ToLower(req.Header["Connection"]) != "upgrade" {
+ if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
+ strings.ToLower(req.Header.Get("Connection")) != "upgrade" {
return
}
// TODO(ukai): check Host
- origin, found := req.Header["Origin"]
- if !found {
+ origin := req.Header.Get("Origin")
+ if origin == "" {
return
}
- key1, found := req.Header["Sec-Websocket-Key1"]
- if !found {
+ key1 := req.Header.Get("Sec-Websocket-Key1")
+ if key1 == "" {
return
}
- key2, found := req.Header["Sec-Websocket-Key2"]
- if !found {
+ key2 := req.Header.Get("Sec-Websocket-Key2")
+ if key2 == "" {
return
}
key3 := make([]byte, 8)
@@ -138,8 +138,8 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n")
buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n")
- protocol, found := req.Header["Sec-Websocket-Protocol"]
- if found {
+ protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
+ if protocol != "" {
buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n")
}
// Step 12. send CRLF.
@@ -167,24 +167,24 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, "Unexpected request")
return
}
- if req.Header["Upgrade"] != "WebSocket" {
+ if req.Header.Get("Upgrade") != "WebSocket" {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Upgrade: WebSocket header")
return
}
- if req.Header["Connection"] != "Upgrade" {
+ if req.Header.Get("Connection") != "Upgrade" {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Connection: Upgrade header")
return
}
- origin, found := req.Header["Origin"]
- if !found {
+ origin := strings.TrimSpace(req.Header.Get("Origin"))
+ if origin == "" {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Origin header")
return
}
- rwc, buf, err := w.Hijack()
+ rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.String())
return
@@ -205,9 +205,9 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
buf.WriteString("WebSocket-Location: " + location + "\r\n")
- protocol, found := req.Header["Websocket-Protocol"]
+ protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol"))
// canonical header key of WebSocket-Protocol.
- if found {
+ if protocol != "" {
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
}
buf.WriteString("\r\n")
diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go
index 204a9de1e..14d708a3b 100644
--- a/src/pkg/websocket/websocket_test.go
+++ b/src/pkg/websocket/websocket_test.go
@@ -9,6 +9,7 @@ import (
"bytes"
"fmt"
"http"
+ "http/httptest"
"io"
"log"
"net"
@@ -22,15 +23,11 @@ var once sync.Once
func echoServer(ws *Conn) { io.Copy(ws, ws) }
func startServer() {
- l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
- if e != nil {
- log.Fatalf("net.Listen tcp :0 %v", e)
- }
- serverAddr = l.Addr().String()
- log.Print("Test WebSocket server listening on ", serverAddr)
http.Handle("/echo", Handler(echoServer))
http.Handle("/echoDraft75", Draft75Handler(echoServer))
- go http.Serve(l, nil)
+ server := httptest.NewServer(nil)
+ serverAddr = server.Listener.Addr().String()
+ log.Print("Test WebSocket server listening on ", serverAddr)
}
// Test the getChallengeResponse function with values from section