summaryrefslogtreecommitdiff
path: root/src/pkg/rpc
diff options
context:
space:
mode:
authorOndřej Surý <ondrej@sury.org>2011-01-17 12:40:45 +0100
committerOndřej Surý <ondrej@sury.org>2011-01-17 12:40:45 +0100
commit3e45412327a2654a77944249962b3652e6142299 (patch)
treebc3bf69452afa055423cbe0c5cfa8ca357df6ccf /src/pkg/rpc
parentc533680039762cacbc37db8dc7eed074c3e497be (diff)
downloadgolang-upstream/2011.01.12.tar.gz
Imported Upstream version 2011.01.12upstream/2011.01.12
Diffstat (limited to 'src/pkg/rpc')
-rw-r--r--src/pkg/rpc/Makefile2
-rw-r--r--src/pkg/rpc/client.go21
-rw-r--r--src/pkg/rpc/debug.go36
-rw-r--r--src/pkg/rpc/jsonrpc/Makefile2
-rw-r--r--src/pkg/rpc/jsonrpc/all_test.go13
-rw-r--r--src/pkg/rpc/jsonrpc/client.go17
-rw-r--r--src/pkg/rpc/jsonrpc/server.go14
-rw-r--r--src/pkg/rpc/server.go191
-rw-r--r--src/pkg/rpc/server_test.go126
9 files changed, 308 insertions, 114 deletions
diff --git a/src/pkg/rpc/Makefile b/src/pkg/rpc/Makefile
index 4757b3aae..191b10d05 100644
--- a/src/pkg/rpc/Makefile
+++ b/src/pkg/rpc/Makefile
@@ -2,7 +2,7 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
-include ../../Make.$(GOARCH)
+include ../../Make.inc
TARG=rpc
GOFILES=\
diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go
index d742d099f..601c49715 100644
--- a/src/pkg/rpc/client.go
+++ b/src/pkg/rpc/client.go
@@ -69,12 +69,12 @@ func (client *Client) send(c *Call) {
// Encode and send the request.
request := new(Request)
client.sending.Lock()
+ defer client.sending.Unlock()
request.Seq = c.seq
request.ServiceMethod = c.ServiceMethod
if err := client.codec.WriteRequest(request, c.Args); err != nil {
panic("rpc: client encode error: " + err.String())
}
- client.sending.Unlock()
}
func (client *Client) input() {
@@ -94,10 +94,12 @@ func (client *Client) input() {
client.pending[seq] = c, false
client.mutex.Unlock()
err = client.codec.ReadResponseBody(c.Reply)
- // Empty strings should turn into nil os.Errors
if response.Error != "" {
c.Error = os.ErrorString(response.Error)
+ } else if err != nil {
+ c.Error = err
} else {
+ // Empty strings should turn into nil os.Errors
c.Error = nil
}
// We don't want to block here. It is the caller's responsibility to make
@@ -113,7 +115,7 @@ func (client *Client) input() {
}
client.mutex.Unlock()
if err != os.EOF || !client.closing {
- log.Stderr("rpc: client protocol error:", err)
+ log.Println("rpc: client protocol error:", err)
}
}
@@ -160,14 +162,21 @@ func (c *gobClientCodec) Close() os.Error {
}
-// DialHTTP connects to an HTTP RPC server at the specified network address.
+// DialHTTP connects to an HTTP RPC server at the specified network address
+// listening on the default HTTP RPC path.
func DialHTTP(network, address string) (*Client, os.Error) {
+ return DialHTTPPath(network, address, DefaultRPCPath)
+}
+
+// DialHTTPPath connects to an HTTP RPC server
+// at the specified network address and path.
+func DialHTTPPath(network, address, path string) (*Client, os.Error) {
var err os.Error
conn, err := net.Dial(network, "", address)
if err != nil {
return nil, err
}
- io.WriteString(conn, "CONNECT "+rpcPath+" HTTP/1.0\n\n")
+ io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
// Require successful HTTP response
// before switching to RPC protocol.
@@ -218,7 +227,7 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
// RPCs that will be using that channel. If the channel
// is totally unbuffered, it's best not to run at all.
if cap(done) == 0 {
- log.Crash("rpc: done channel is unbuffered")
+ log.Panic("rpc: done channel is unbuffered")
}
}
c.Done = done
diff --git a/src/pkg/rpc/debug.go b/src/pkg/rpc/debug.go
index 638584f49..44b32e04b 100644
--- a/src/pkg/rpc/debug.go
+++ b/src/pkg/rpc/debug.go
@@ -21,14 +21,14 @@ const debugText = `<html>
<title>Services</title>
{.repeated section @}
<hr>
- Service {name}
+ Service {Name}
<hr>
<table>
<th align=center>Method</th><th align=center>Calls</th>
- {.repeated section meth}
+ {.repeated section Method}
<tr>
- <td align=left font=fixed>{name}({m.argType}, {m.replyType}) os.Error</td>
- <td align=center>{m.numCalls}</td>
+ <td align=left font=fixed>{Name}({Type.ArgType}, {Type.ReplyType}) os.Error</td>
+ <td align=center>{Type.NumCalls}</td>
</tr>
{.end}
</table>
@@ -39,30 +39,34 @@ const debugText = `<html>
var debug = template.MustParse(debugText, nil)
type debugMethod struct {
- m *methodType
- name string
+ Type *methodType
+ Name string
}
type methodArray []debugMethod
type debugService struct {
- s *service
- name string
- meth methodArray
+ Service *service
+ Name string
+ Method methodArray
}
type serviceArray []debugService
func (s serviceArray) Len() int { return len(s) }
-func (s serviceArray) Less(i, j int) bool { return s[i].name < s[j].name }
+func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (m methodArray) Len() int { return len(m) }
-func (m methodArray) Less(i, j int) bool { return m[i].name < m[j].name }
+func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+type debugHTTP struct {
+ *Server
+}
+
// Runs at /debug/rpc
-func debugHTTP(c *http.Conn, req *http.Request) {
+func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data.
var services = make(serviceArray, len(server.serviceMap))
i := 0
@@ -71,16 +75,16 @@ func debugHTTP(c *http.Conn, req *http.Request) {
services[i] = debugService{service, sname, make(methodArray, len(service.method))}
j := 0
for mname, method := range service.method {
- services[i].meth[j] = debugMethod{method, mname}
+ services[i].Method[j] = debugMethod{method, mname}
j++
}
- sort.Sort(services[i].meth)
+ sort.Sort(services[i].Method)
i++
}
server.Unlock()
sort.Sort(services)
- err := debug.Execute(services, c)
+ err := debug.Execute(services, w)
if err != nil {
- fmt.Fprintln(c, "rpc: error executing template:", err.String())
+ fmt.Fprintln(w, "rpc: error executing template:", err.String())
}
}
diff --git a/src/pkg/rpc/jsonrpc/Makefile b/src/pkg/rpc/jsonrpc/Makefile
index 1a4fd2e92..b9a1ac2f7 100644
--- a/src/pkg/rpc/jsonrpc/Makefile
+++ b/src/pkg/rpc/jsonrpc/Makefile
@@ -2,7 +2,7 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
-include ../../../Make.$(GOARCH)
+include ../../../Make.inc
TARG=rpc/jsonrpc
GOFILES=\
diff --git a/src/pkg/rpc/jsonrpc/all_test.go b/src/pkg/rpc/jsonrpc/all_test.go
index e94c594da..764ee7ff3 100644
--- a/src/pkg/rpc/jsonrpc/all_test.go
+++ b/src/pkg/rpc/jsonrpc/all_test.go
@@ -53,7 +53,7 @@ func TestServer(t *testing.T) {
type addResp struct {
Id interface{} "id"
Result Reply "result"
- Error string "error"
+ Error interface{} "error"
}
cli, srv := net.Pipe()
@@ -69,7 +69,7 @@ func TestServer(t *testing.T) {
if err != nil {
t.Fatalf("Decode: %s", err)
}
- if resp.Error != "" {
+ if resp.Error != nil {
t.Fatalf("resp.Error: %s", resp.Error)
}
if resp.Id.(string) != string(i) {
@@ -79,6 +79,15 @@ func TestServer(t *testing.T) {
t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
}
}
+
+ fmt.Fprintf(cli, "{}\n")
+ var resp addResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
}
func TestClient(t *testing.T) {
diff --git a/src/pkg/rpc/jsonrpc/client.go b/src/pkg/rpc/jsonrpc/client.go
index ed2b4ed37..dcaa69f9d 100644
--- a/src/pkg/rpc/jsonrpc/client.go
+++ b/src/pkg/rpc/jsonrpc/client.go
@@ -7,6 +7,7 @@
package jsonrpc
import (
+ "fmt"
"io"
"json"
"net"
@@ -61,13 +62,13 @@ func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) os.Error {
type clientResponse struct {
Id uint64 "id"
Result *json.RawMessage "result"
- Error string "error"
+ Error interface{} "error"
}
func (r *clientResponse) reset() {
r.Id = 0
r.Result = nil
- r.Error = ""
+ r.Error = nil
}
func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error {
@@ -81,8 +82,18 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error {
c.pending[c.resp.Id] = "", false
c.mutex.Unlock()
+ r.Error = ""
r.Seq = c.resp.Id
- r.Error = c.resp.Error
+ if c.resp.Error != nil {
+ x, ok := c.resp.Error.(string)
+ if !ok {
+ return fmt.Errorf("invalid error %v", c.resp.Error)
+ }
+ if x == "" {
+ x = "unspecified error"
+ }
+ r.Error = x
+ }
return nil
}
diff --git a/src/pkg/rpc/jsonrpc/server.go b/src/pkg/rpc/jsonrpc/server.go
index 9f3472a39..bf53bda8d 100644
--- a/src/pkg/rpc/jsonrpc/server.go
+++ b/src/pkg/rpc/jsonrpc/server.go
@@ -61,7 +61,7 @@ func (r *serverRequest) reset() {
type serverResponse struct {
Id *json.RawMessage "id"
Result interface{} "result"
- Error string "error"
+ Error interface{} "error"
}
func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error {
@@ -94,6 +94,8 @@ func (c *serverCodec) ReadRequestBody(x interface{}) os.Error {
return json.Unmarshal(*c.req.Params, &params)
}
+var null = json.RawMessage([]byte("null"))
+
func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error {
var resp serverResponse
c.mutex.Lock()
@@ -105,9 +107,17 @@ func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error {
c.pending[r.Seq] = nil, false
c.mutex.Unlock()
+ if b == nil {
+ // Invalid request so no id. Use JSON null.
+ b = &null
+ }
resp.Id = b
resp.Result = x
- resp.Error = r.Error
+ if r.Error == "" {
+ resp.Error = nil
+ } else {
+ resp.Error = r.Error
+ }
return c.enc.Encode(resp)
}
diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go
index d14f6ded2..5c50bcc3a 100644
--- a/src/pkg/rpc/server.go
+++ b/src/pkg/rpc/server.go
@@ -3,9 +3,9 @@
// license that can be found in the LICENSE file.
/*
- The rpc package provides access to the public methods of an object across a
+ The rpc package provides access to the exported methods of an object across a
network or other I/O connection. A server registers an object, making it visible
- as a service with the name of the type of the object. After registration, public
+ as a service with the name of the type of the object. After registration, exported
methods of the object will be accessible remotely. A server may register multiple
objects (services) of different types but it is an error to register multiple
objects of the same type.
@@ -13,8 +13,8 @@
Only methods that satisfy these criteria will be made available for remote access;
other methods will be ignored:
- - the method receiver and name are publicly visible, that is, begin with an upper case letter.
- - the method has two arguments, both pointers to publicly visible types.
+ - the method receiver and name are exported, that is, begin with an upper case letter.
+ - the method has two arguments, both pointers to exported types.
- the method has return type os.Error.
The method's first argument represents the arguments provided by the caller; the
@@ -123,6 +123,12 @@ import (
"utf8"
)
+const (
+ // Defaults used by HandleHTTP
+ DefaultRPCPath = "/_goRPC_"
+ DefaultDebugPath = "/debug/rpc"
+)
+
// Precompute the reflect type for os.Error. Can't use os.Error directly
// because Typeof takes an empty interface value. This is annoying.
var unusedError *os.Error
@@ -131,8 +137,8 @@ var typeOfOsError = reflect.Typeof(unusedError).(*reflect.PtrType).Elem()
type methodType struct {
sync.Mutex // protects counters
method reflect.Method
- argType *reflect.PtrType
- replyType *reflect.PtrType
+ ArgType *reflect.PtrType
+ ReplyType *reflect.PtrType
numCalls uint
}
@@ -166,23 +172,46 @@ type ClientInfo struct {
RemoteAddr string
}
-type serverType struct {
+// Server represents an RPC Server.
+type Server struct {
sync.Mutex // protects the serviceMap
serviceMap map[string]*service
}
-// This variable is a global whose "public" methods are really private methods
-// called from the global functions of this package: rpc.Register, rpc.ServeConn, etc.
-// For example, rpc.Register() calls server.add().
-var server = &serverType{serviceMap: make(map[string]*service)}
+// NewServer returns a new Server.
+func NewServer() *Server {
+ return &Server{serviceMap: make(map[string]*service)}
+}
-// Is this a publicly visible - upper case - name?
-func isPublic(name string) bool {
+// DefaultServer is the default instance of *Server.
+var DefaultServer = NewServer()
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
-func (server *serverType) register(rcvr interface{}) os.Error {
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method
+// - two arguments, both pointers to exported structs
+// - one return value, of type os.Error
+// It returns an error if the receiver is not an exported type or has no
+// suitable methods.
+// The client accesses each method using a string of the form "Type.Method",
+// where Type is the receiver's concrete type.
+func (server *Server) Register(rcvr interface{}) os.Error {
+ return server.register(rcvr, "", false)
+}
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func (server *Server) RegisterName(name string, rcvr interface{}) os.Error {
+ return server.register(rcvr, name, true)
+}
+
+func (server *Server) register(rcvr interface{}, name string, useName bool) os.Error {
server.Lock()
defer server.Unlock()
if server.serviceMap == nil {
@@ -192,12 +221,15 @@ func (server *serverType) register(rcvr interface{}) os.Error {
s.typ = reflect.Typeof(rcvr)
s.rcvr = reflect.NewValue(rcvr)
sname := reflect.Indirect(s.rcvr).Type().Name()
+ if useName {
+ sname = name
+ }
if sname == "" {
log.Exit("rpc: no service name for type", s.typ.String())
}
- if s.typ.PkgPath() != "" && !isPublic(sname) {
- s := "rpc Register: type " + sname + " is not public"
- log.Stderr(s)
+ if s.typ.PkgPath() != "" && !isExported(sname) && !useName {
+ s := "rpc Register: type " + sname + " is not exported"
+ log.Print(s)
return os.ErrorString(s)
}
if _, present := server.serviceMap[sname]; present {
@@ -211,54 +243,54 @@ func (server *serverType) register(rcvr interface{}) os.Error {
method := s.typ.Method(m)
mtype := method.Type
mname := method.Name
- if mtype.PkgPath() != "" && !isPublic(mname) {
+ if mtype.PkgPath() != "" || !isExported(mname) {
continue
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
- log.Stderr("method", mname, "has wrong number of ins:", mtype.NumIn())
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
continue
}
argType, ok := mtype.In(1).(*reflect.PtrType)
if !ok {
- log.Stderr(mname, "arg type not a pointer:", mtype.In(1))
+ log.Println(mname, "arg type not a pointer:", mtype.In(1))
continue
}
replyType, ok := mtype.In(2).(*reflect.PtrType)
if !ok {
- log.Stderr(mname, "reply type not a pointer:", mtype.In(2))
+ log.Println(mname, "reply type not a pointer:", mtype.In(2))
continue
}
- if argType.Elem().PkgPath() != "" && !isPublic(argType.Elem().Name()) {
- log.Stderr(mname, "argument type not public:", argType)
+ if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) {
+ log.Println(mname, "argument type not exported:", argType)
continue
}
- if replyType.Elem().PkgPath() != "" && !isPublic(replyType.Elem().Name()) {
- log.Stderr(mname, "reply type not public:", replyType)
+ if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) {
+ log.Println(mname, "reply type not exported:", replyType)
continue
}
if mtype.NumIn() == 4 {
t := mtype.In(3)
if t != reflect.Typeof((*ClientInfo)(nil)) {
- log.Stderr(mname, "last argument not *ClientInfo")
+ log.Println(mname, "last argument not *ClientInfo")
continue
}
}
// Method needs one out: os.Error.
if mtype.NumOut() != 1 {
- log.Stderr("method", mname, "has wrong number of outs:", mtype.NumOut())
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
continue
}
if returnType := mtype.Out(0); returnType != typeOfOsError {
- log.Stderr("method", mname, "returns", returnType.String(), "not os.Error")
+ log.Println("method", mname, "returns", returnType.String(), "not os.Error")
continue
}
- s.method[mname] = &methodType{method: method, argType: argType, replyType: replyType}
+ s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
}
if len(s.method) == 0 {
- s := "rpc Register: type " + sname + " has no public methods of suitable type"
- log.Stderr(s)
+ s := "rpc Register: type " + sname + " has no exported methods of suitable type"
+ log.Print(s)
return os.ErrorString(s)
}
server.serviceMap[s.name] = s
@@ -289,11 +321,18 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se
sending.Lock()
err := codec.WriteResponse(resp, reply)
if err != nil {
- log.Stderr("rpc: writing response: ", err)
+ log.Println("rpc: writing response:", err)
}
sending.Unlock()
}
+func (m *methodType) NumCalls() (n uint) {
+ m.Lock()
+ n = m.numCalls
+ m.Unlock()
+ return n
+}
+
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
mtype.Lock()
mtype.numCalls++
@@ -335,7 +374,19 @@ func (c *gobServerCodec) Close() os.Error {
return c.rwc.Close()
}
-func (server *serverType) input(codec ServerCodec) {
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
+ server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
for {
// Grab the request header.
@@ -344,7 +395,7 @@ func (server *serverType) input(codec ServerCodec) {
if err != nil {
if err == os.EOF || err == io.ErrUnexpectedEOF {
if err == io.ErrUnexpectedEOF {
- log.Stderr("rpc: ", err)
+ log.Println("rpc:", err)
}
break
}
@@ -374,11 +425,11 @@ func (server *serverType) input(codec ServerCodec) {
continue
}
// Decode the argument value.
- argv := _new(mtype.argType)
- replyv := _new(mtype.replyType)
+ argv := _new(mtype.ArgType)
+ replyv := _new(mtype.ReplyType)
err = codec.ReadRequestBody(argv.Interface())
if err != nil {
- log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err)
+ log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
sendResponse(sending, req, replyv.Interface(), codec, err.String())
break
}
@@ -387,24 +438,27 @@ func (server *serverType) input(codec ServerCodec) {
codec.Close()
}
-func (server *serverType) accept(lis net.Listener) {
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection. Accept blocks; the caller typically
+// invokes it in a go statement.
+func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
}
- go ServeConn(conn)
+ go server.ServeConn(conn)
}
}
-// Register publishes in the server the set of methods of the
-// receiver value that satisfy the following conditions:
-// - public method
-// - two arguments, both pointers to public structs
-// - one return value of type os.Error
-// It returns an error if the receiver is not public or has no
-// suitable methods.
-func Register(rcvr interface{}) os.Error { return server.register(rcvr) }
+// Register publishes the receiver's methods in the DefaultServer.
+func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) }
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func RegisterName(name string, rcvr interface{}) os.Error {
+ return DefaultServer.RegisterName(name, rcvr)
+}
// A ServerCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session.
@@ -420,50 +474,57 @@ type ServerCodec interface {
Close() os.Error
}
-// ServeConn runs the server on a single connection.
+// ServeConn runs the DefaultServer on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func ServeConn(conn io.ReadWriteCloser) {
- ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+ DefaultServer.ServeConn(conn)
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
- server.input(codec)
+ DefaultServer.ServeCodec(codec)
}
// Accept accepts connections on the listener and serves requests
-// for each incoming connection. Accept blocks; the caller typically
-// invokes it in a go statement.
-func Accept(lis net.Listener) { server.accept(lis) }
+// to DefaultServer for each incoming connection.
+// Accept blocks; the caller typically invokes it in a go statement.
+func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Can connect to RPC service using HTTP CONNECT to rpcPath.
-var rpcPath string = "/_goRPC_"
-var debugPath string = "/debug/rpc"
var connected = "200 Connected to Go RPC"
-func serveHTTP(c *http.Conn, req *http.Request) {
+// ServeHTTP implements an http.Handler that answers RPC requests.
+func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
- c.SetHeader("Content-Type", "text/plain; charset=utf-8")
- c.WriteHeader(http.StatusMethodNotAllowed)
- io.WriteString(c, "405 must CONNECT to "+rpcPath+"\n")
+ w.SetHeader("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ io.WriteString(w, "405 must CONNECT\n")
return
}
- conn, _, err := c.Hijack()
+ conn, _, err := w.Hijack()
if err != nil {
- log.Stderr("rpc hijacking ", c.RemoteAddr, ": ", err.String())
+ log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String())
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
- ServeConn(conn)
+ server.ServeConn(conn)
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
+// and a debugging handler on debugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func (server *Server) HandleHTTP(rpcPath, debugPath string) {
+ http.Handle(rpcPath, server)
+ http.Handle(debugPath, debugHTTP{server})
}
-// HandleHTTP registers an HTTP handler for RPC messages.
+// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
+// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func HandleHTTP() {
- http.Handle(rpcPath, http.HandlerFunc(serveHTTP))
- http.Handle(debugPath, http.HandlerFunc(debugHTTP))
+ DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go
index e502db4e3..355d51ce4 100644
--- a/src/pkg/rpc/server_test.go
+++ b/src/pkg/rpc/server_test.go
@@ -9,17 +9,23 @@ import (
"http"
"log"
"net"
- "once"
"os"
"strings"
+ "sync"
"testing"
+ "time"
)
-var serverAddr string
-var httpServerAddr string
-
-const second = 1e9
+var (
+ serverAddr, newServerAddr string
+ httpServerAddr string
+ once, newOnce, httpOnce sync.Once
+)
+const (
+ second = 1e9
+ newHttpPath = "/foo"
+)
type Args struct {
A, B int
@@ -63,32 +69,56 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error {
panic("ERROR")
}
-func startServer() {
- Register(new(Arith))
-
+func listenTCP() (net.Listener, string) {
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
if e != nil {
log.Exitf("net.Listen tcp :0: %v", e)
}
- serverAddr = l.Addr().String()
- log.Stderr("Test RPC server listening on ", serverAddr)
+ return l, l.Addr().String()
+}
+
+func startServer() {
+ Register(new(Arith))
+
+ var l net.Listener
+ l, serverAddr = listenTCP()
+ log.Println("Test RPC server listening on", serverAddr)
go Accept(l)
HandleHTTP()
- l, e = net.Listen("tcp", "127.0.0.1:0") // any available address
- if e != nil {
- log.Stderrf("net.Listen tcp :0: %v", e)
- os.Exit(1)
- }
+ httpOnce.Do(startHttpServer)
+}
+
+func startNewServer() {
+ s := NewServer()
+ s.Register(new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ log.Println("NewServer test RPC server listening on", newServerAddr)
+ go Accept(l)
+
+ s.HandleHTTP(newHttpPath, "/bar")
+ httpOnce.Do(startHttpServer)
+}
+
+func startHttpServer() {
+ var l net.Listener
+ l, httpServerAddr = listenTCP()
httpServerAddr = l.Addr().String()
- log.Stderr("Test HTTP RPC server listening on ", httpServerAddr)
+ log.Println("Test HTTP RPC server listening on", httpServerAddr)
go http.Serve(l, nil)
}
func TestRPC(t *testing.T) {
once.Do(startServer)
+ testRPC(t, serverAddr)
+ newOnce.Do(startNewServer)
+ testRPC(t, newServerAddr)
+}
- client, err := Dial("tcp", serverAddr)
+func testRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
if err != nil {
t.Fatal("dialing", err)
}
@@ -174,8 +204,19 @@ func TestRPC(t *testing.T) {
func TestHTTPRPC(t *testing.T) {
once.Do(startServer)
+ testHTTPRPC(t, "")
+ newOnce.Do(startNewServer)
+ testHTTPRPC(t, newHttpPath)
+}
- client, err := DialHTTP("tcp", httpServerAddr)
+func testHTTPRPC(t *testing.T, path string) {
+ var client *Client
+ var err os.Error
+ if path == "" {
+ client, err = DialHTTP("tcp", httpServerAddr)
+ } else {
+ client, err = DialHTTPPath("tcp", httpServerAddr, path)
+ }
if err != nil {
t.Fatal("dialing", err)
}
@@ -292,3 +333,52 @@ func TestRegistrationError(t *testing.T) {
t.Errorf("expected error registering ReplyNotPublic")
}
}
+
+type WriteFailCodec int
+
+func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error {
+ // the panic caused by this error used to not unlock a lock.
+ return os.NewError("fail")
+}
+
+func (WriteFailCodec) ReadResponseHeader(*Response) os.Error {
+ time.Sleep(60e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) ReadResponseBody(interface{}) os.Error {
+ time.Sleep(60e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) Close() os.Error {
+ return nil
+}
+
+func TestSendDeadlock(t *testing.T) {
+ client := NewClientWithCodec(WriteFailCodec(0))
+
+ done := make(chan bool)
+ go func() {
+ testSendDeadlock(client)
+ testSendDeadlock(client)
+ done <- true
+ }()
+ for i := 0; i < 50; i++ {
+ time.Sleep(100 * 1e6)
+ _, ok := <-done
+ if ok {
+ return
+ }
+ }
+ t.Fatal("deadlock")
+}
+
+func testSendDeadlock(client *Client) {
+ defer func() {
+ recover()
+ }()
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client.Call("Arith.Add", args, reply)
+}