diff options
Diffstat (limited to 'src/pkg/websocket/server.go')
-rw-r--r-- | src/pkg/websocket/server.go | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index dd797f24e..37149f044 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -58,7 +58,7 @@ func getKeyNumber(s string) (r uint32) { // ServeHTTP implements the http.Handler interface for a Web Socket func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - rwc, buf, err := w.Hijack() + rwc, buf, err := w.(http.Hijacker).Hijack() if err != nil { panic("Hijack failed: " + err.String()) return @@ -73,23 +73,23 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // HTTP version can be safely ignored. - if strings.ToLower(req.Header["Upgrade"]) != "websocket" || - strings.ToLower(req.Header["Connection"]) != "upgrade" { + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(req.Header.Get("Connection")) != "upgrade" { return } // TODO(ukai): check Host - origin, found := req.Header["Origin"] - if !found { + origin := req.Header.Get("Origin") + if origin == "" { return } - key1, found := req.Header["Sec-Websocket-Key1"] - if !found { + key1 := req.Header.Get("Sec-Websocket-Key1") + if key1 == "" { return } - key2, found := req.Header["Sec-Websocket-Key2"] - if !found { + key2 := req.Header.Get("Sec-Websocket-Key2") + if key2 == "" { return } key3 := make([]byte, 8) @@ -138,8 +138,8 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n") buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n") - protocol, found := req.Header["Sec-Websocket-Protocol"] - if found { + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + if protocol != "" { buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n") } // Step 12. send CRLF. @@ -167,24 +167,24 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { io.WriteString(w, "Unexpected request") return } - if req.Header["Upgrade"] != "WebSocket" { + if req.Header.Get("Upgrade") != "WebSocket" { w.WriteHeader(http.StatusBadRequest) io.WriteString(w, "missing Upgrade: WebSocket header") return } - if req.Header["Connection"] != "Upgrade" { + if req.Header.Get("Connection") != "Upgrade" { w.WriteHeader(http.StatusBadRequest) io.WriteString(w, "missing Connection: Upgrade header") return } - origin, found := req.Header["Origin"] - if !found { + origin := strings.TrimSpace(req.Header.Get("Origin")) + if origin == "" { w.WriteHeader(http.StatusBadRequest) io.WriteString(w, "missing Origin header") return } - rwc, buf, err := w.Hijack() + rwc, buf, err := w.(http.Hijacker).Hijack() if err != nil { panic("Hijack failed: " + err.String()) return @@ -205,9 +205,9 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("WebSocket-Origin: " + origin + "\r\n") buf.WriteString("WebSocket-Location: " + location + "\r\n") - protocol, found := req.Header["Websocket-Protocol"] + protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol")) // canonical header key of WebSocket-Protocol. - if found { + if protocol != "" { buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") } buf.WriteString("\r\n") |