From 17fb1d2cc13f0f40b9e0341d1ee1b284f30cf58e Mon Sep 17 00:00:00 2001 From: Fumitoshi Ukai Date: Thu, 18 Feb 2010 18:32:40 -0800 Subject: http: avoid server crash on malformed client request R=r, rsc CC=golang-dev http://codereview.appspot.com/206079 Committer: Russ Cox --- src/pkg/websocket/server.go | 30 ++++++++++++++++++++++-------- src/pkg/websocket/websocket_test.go | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) (limited to 'src/pkg') diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 43c2a7c8d..0ccb31e8a 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -38,20 +38,34 @@ type Handler func(*Conn) // ServeHTTP implements the http.Handler interface for a Web Socket. func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { - if req.Method != "GET" || req.Proto != "HTTP/1.1" || - req.Header["Upgrade"] != "WebSocket" || - req.Header["Connection"] != "Upgrade" { - c.WriteHeader(http.StatusNotFound) - io.WriteString(c, "must use websocket to connect here") + if req.Method != "GET" || req.Proto != "HTTP/1.1" { + c.WriteHeader(http.StatusBadRequest) + io.WriteString(c, "Unexpected request") return } + if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" { + c.WriteHeader(http.StatusBadRequest) + io.WriteString(c, "missing Upgrade: WebSocket header") + return + } + if v, present := req.Header["Connection"]; !present || v != "Upgrade" { + c.WriteHeader(http.StatusBadRequest) + io.WriteString(c, "missing Connection: Upgrade header") + return + } + origin, present := req.Header["Origin"] + if !present { + c.WriteHeader(http.StatusBadRequest) + io.WriteString(c, "missing Origin header") + return + } + rwc, buf, err := c.Hijack() if err != nil { panic("Hijack failed: ", err.String()) return } defer rwc.Close() - origin := req.Header["Origin"] location := "ws://" + req.Host + req.URL.Path // TODO(ukai): verify origin,location,protocol. @@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("WebSocket-Origin: " + origin + "\r\n") buf.WriteString("WebSocket-Location: " + location + "\r\n") - protocol := "" + protocol, present := req.Header["Websocket-Protocol"] // canonical header key of WebSocket-Protocol. - if protocol, found := req.Header["Websocket-Protocol"]; found { + if present { buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") } buf.WriteString("\r\n") diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index c62604621..c15c43538 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -6,6 +6,7 @@ package websocket import ( "bytes" + "fmt" "http" "io" "log" @@ -59,3 +60,17 @@ func TestEcho(t *testing.T) { } ws.Close() } + +func TestHTTP(t *testing.T) { + once.Do(startServer) + + r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) + if err != nil { + t.Errorf("Get: error %v", err) + return + } + if r.StatusCode != http.StatusBadRequest { + t.Errorf("Get: got status %d", r.StatusCode) + return + } +} -- cgit v1.2.3