diff options
Diffstat (limited to 'src/pkg/websocket/server.go')
| -rw-r--r-- | src/pkg/websocket/server.go | 30 | 
1 files changed, 22 insertions, 8 deletions
| diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 43c2a7c8d..0ccb31e8a 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -38,20 +38,34 @@ type Handler func(*Conn)  // ServeHTTP implements the http.Handler interface for a Web Socket.  func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { -	if req.Method != "GET" || req.Proto != "HTTP/1.1" || -		req.Header["Upgrade"] != "WebSocket" || -		req.Header["Connection"] != "Upgrade" { -		c.WriteHeader(http.StatusNotFound) -		io.WriteString(c, "must use websocket to connect here") +	if req.Method != "GET" || req.Proto != "HTTP/1.1" { +		c.WriteHeader(http.StatusBadRequest) +		io.WriteString(c, "Unexpected request")  		return  	} +	if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" { +		c.WriteHeader(http.StatusBadRequest) +		io.WriteString(c, "missing Upgrade: WebSocket header") +		return +	} +	if v, present := req.Header["Connection"]; !present || v != "Upgrade" { +		c.WriteHeader(http.StatusBadRequest) +		io.WriteString(c, "missing Connection: Upgrade header") +		return +	} +	origin, present := req.Header["Origin"] +	if !present { +		c.WriteHeader(http.StatusBadRequest) +		io.WriteString(c, "missing Origin header") +		return +	} +  	rwc, buf, err := c.Hijack()  	if err != nil {  		panic("Hijack failed: ", err.String())  		return  	}  	defer rwc.Close() -	origin := req.Header["Origin"]  	location := "ws://" + req.Host + req.URL.Path  	// TODO(ukai): verify origin,location,protocol. @@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {  	buf.WriteString("Connection: Upgrade\r\n")  	buf.WriteString("WebSocket-Origin: " + origin + "\r\n")  	buf.WriteString("WebSocket-Location: " + location + "\r\n") -	protocol := "" +	protocol, present := req.Header["Websocket-Protocol"]  	// canonical header key of WebSocket-Protocol. -	if protocol, found := req.Header["Websocket-Protocol"]; found { +	if present {  		buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")  	}  	buf.WriteString("\r\n") | 
