summaryrefslogtreecommitdiff
path: root/src/pkg
diff options
context:
space:
mode:
authorFumitoshi Ukai <ukai@google.com>2010-02-18 18:32:40 -0800
committerFumitoshi Ukai <ukai@google.com>2010-02-18 18:32:40 -0800
commit17fb1d2cc13f0f40b9e0341d1ee1b284f30cf58e (patch)
tree3847afbcb18fe3174a15de2a2eca135388c5ab88 /src/pkg
parent8690a162c8b2b69c7bd81917d8ee97b29bc9cfed (diff)
downloadgolang-17fb1d2cc13f0f40b9e0341d1ee1b284f30cf58e.tar.gz
http: avoid server crash on malformed client request
R=r, rsc CC=golang-dev http://codereview.appspot.com/206079 Committer: Russ Cox <rsc@golang.org>
Diffstat (limited to 'src/pkg')
-rw-r--r--src/pkg/websocket/server.go30
-rw-r--r--src/pkg/websocket/websocket_test.go15
2 files changed, 37 insertions, 8 deletions
diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go
index 43c2a7c8d..0ccb31e8a 100644
--- a/src/pkg/websocket/server.go
+++ b/src/pkg/websocket/server.go
@@ -38,20 +38,34 @@ type Handler func(*Conn)
// ServeHTTP implements the http.Handler interface for a Web Socket.
func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
- if req.Method != "GET" || req.Proto != "HTTP/1.1" ||
- req.Header["Upgrade"] != "WebSocket" ||
- req.Header["Connection"] != "Upgrade" {
- c.WriteHeader(http.StatusNotFound)
- io.WriteString(c, "must use websocket to connect here")
+ if req.Method != "GET" || req.Proto != "HTTP/1.1" {
+ c.WriteHeader(http.StatusBadRequest)
+ io.WriteString(c, "Unexpected request")
return
}
+ if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
+ c.WriteHeader(http.StatusBadRequest)
+ io.WriteString(c, "missing Upgrade: WebSocket header")
+ return
+ }
+ if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
+ c.WriteHeader(http.StatusBadRequest)
+ io.WriteString(c, "missing Connection: Upgrade header")
+ return
+ }
+ origin, present := req.Header["Origin"]
+ if !present {
+ c.WriteHeader(http.StatusBadRequest)
+ io.WriteString(c, "missing Origin header")
+ return
+ }
+
rwc, buf, err := c.Hijack()
if err != nil {
panic("Hijack failed: ", err.String())
return
}
defer rwc.Close()
- origin := req.Header["Origin"]
location := "ws://" + req.Host + req.URL.Path
// TODO(ukai): verify origin,location,protocol.
@@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
buf.WriteString("WebSocket-Location: " + location + "\r\n")
- protocol := ""
+ protocol, present := req.Header["Websocket-Protocol"]
// canonical header key of WebSocket-Protocol.
- if protocol, found := req.Header["Websocket-Protocol"]; found {
+ if present {
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
}
buf.WriteString("\r\n")
diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go
index c62604621..c15c43538 100644
--- a/src/pkg/websocket/websocket_test.go
+++ b/src/pkg/websocket/websocket_test.go
@@ -6,6 +6,7 @@ package websocket
import (
"bytes"
+ "fmt"
"http"
"io"
"log"
@@ -59,3 +60,17 @@ func TestEcho(t *testing.T) {
}
ws.Close()
}
+
+func TestHTTP(t *testing.T) {
+ once.Do(startServer)
+
+ r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
+ if err != nil {
+ t.Errorf("Get: error %v", err)
+ return
+ }
+ if r.StatusCode != http.StatusBadRequest {
+ t.Errorf("Get: got status %d", r.StatusCode)
+ return
+ }
+}