summaryrefslogtreecommitdiff
path: root/src/pkg/net/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/net/rpc')
-rw-r--r--src/pkg/net/rpc/client.go43
-rw-r--r--src/pkg/net/rpc/jsonrpc/all_test.go61
-rw-r--r--src/pkg/net/rpc/jsonrpc/server.go13
-rw-r--r--src/pkg/net/rpc/server.go84
-rw-r--r--src/pkg/net/rpc/server_test.go72
5 files changed, 193 insertions, 80 deletions
diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go
index db2da8e44..4b0c9c3bb 100644
--- a/src/pkg/net/rpc/client.go
+++ b/src/pkg/net/rpc/client.go
@@ -71,7 +71,7 @@ func (client *Client) send(call *Call) {
// Register this call.
client.mutex.Lock()
- if client.shutdown {
+ if client.shutdown || client.closing {
call.Error = ErrShutdown
client.mutex.Unlock()
call.done()
@@ -88,10 +88,13 @@ func (client *Client) send(call *Call) {
err := client.codec.WriteRequest(&client.request, call.Args)
if err != nil {
client.mutex.Lock()
+ call = client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
- call.Error = err
- call.done()
+ if call != nil {
+ call.Error = err
+ call.done()
+ }
}
}
@@ -102,9 +105,6 @@ func (client *Client) input() {
response = Response{}
err = client.codec.ReadResponseHeader(&response)
if err != nil {
- if err == io.EOF && !client.closing {
- err = io.ErrUnexpectedEOF
- }
break
}
seq := response.Seq
@@ -113,12 +113,18 @@ func (client *Client) input() {
delete(client.pending, seq)
client.mutex.Unlock()
- if response.Error == "" {
- err = client.codec.ReadResponseBody(call.Reply)
+ switch {
+ case call == nil:
+ // We've got no pending call. That usually means that
+ // WriteRequest partially failed, and call was already
+ // removed; response is a server telling us about an
+ // error reading request body. We should still attempt
+ // to read error body, but there's no one to give it to.
+ err = client.codec.ReadResponseBody(nil)
if err != nil {
- call.Error = errors.New("reading body " + err.Error())
+ err = errors.New("reading error body: " + err.Error())
}
- } else {
+ case response.Error != "":
// We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody
// error if there is one.
@@ -127,14 +133,27 @@ func (client *Client) input() {
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
+ call.done()
+ default:
+ err = client.codec.ReadResponseBody(call.Reply)
+ if err != nil {
+ call.Error = errors.New("reading body " + err.Error())
+ }
+ call.done()
}
- call.done()
}
// Terminate pending calls.
client.sending.Lock()
client.mutex.Lock()
client.shutdown = true
closing := client.closing
+ if err == io.EOF {
+ if closing {
+ err = ErrShutdown
+ } else {
+ err = io.ErrUnexpectedEOF
+ }
+ }
for _, call := range client.pending {
call.Error = err
call.done()
@@ -213,7 +232,7 @@ func DialHTTP(network, address string) (*Client, error) {
return DialHTTPPath(network, address, DefaultRPCPath)
}
-// DialHTTPPath connects to an HTTP RPC server
+// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, error) {
var err error
diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go
index e6c7441f0..3c7c4d48f 100644
--- a/src/pkg/net/rpc/jsonrpc/all_test.go
+++ b/src/pkg/net/rpc/jsonrpc/all_test.go
@@ -24,6 +24,12 @@ type Reply struct {
type Arith int
+type ArithAddResp struct {
+ Id interface{} `json:"id"`
+ Result Reply `json:"result"`
+ Error interface{} `json:"error"`
+}
+
func (t *Arith) Add(args *Args, reply *Reply) error {
reply.C = args.A + args.B
return nil
@@ -50,13 +56,39 @@ func init() {
rpc.Register(new(Arith))
}
-func TestServer(t *testing.T) {
- type addResp struct {
- Id interface{} `json:"id"`
- Result Reply `json:"result"`
- Error interface{} `json:"error"`
+func TestServerNoParams(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after no params: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestServerEmptyMessage(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, "{}")
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
}
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+func TestServer(t *testing.T) {
cli, srv := net.Pipe()
defer cli.Close()
go ServeConn(srv)
@@ -65,7 +97,7 @@ func TestServer(t *testing.T) {
// Send hand-coded requests to server, parse responses.
for i := 0; i < 10; i++ {
fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
- var resp addResp
+ var resp ArithAddResp
err := dec.Decode(&resp)
if err != nil {
t.Fatalf("Decode: %s", err)
@@ -80,15 +112,6 @@ func TestServer(t *testing.T) {
t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
}
}
-
- fmt.Fprintf(cli, "{}\n")
- var resp addResp
- if err := dec.Decode(&resp); err != nil {
- t.Fatalf("Decode after empty: %s", err)
- }
- if resp.Error == nil {
- t.Fatalf("Expected error, got nil")
- }
}
func TestClient(t *testing.T) {
@@ -108,7 +131,7 @@ func TestClient(t *testing.T) {
t.Errorf("Add: expected no error but got string %q", err.Error())
}
if reply.C != args.A+args.B {
- t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
}
args = &Args{7, 8}
@@ -118,7 +141,7 @@ func TestClient(t *testing.T) {
t.Errorf("Mul: expected no error but got string %q", err.Error())
}
if reply.C != args.A*args.B {
- t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
}
// Out of order.
@@ -133,7 +156,7 @@ func TestClient(t *testing.T) {
t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
}
if addReply.C != args.A+args.B {
- t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
}
mulCall = <-mulCall.Done
@@ -141,7 +164,7 @@ func TestClient(t *testing.T) {
t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
}
if mulReply.C != args.A*args.B {
- t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
}
// Error test
diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go
index 4c54553a7..5bc05fd0a 100644
--- a/src/pkg/net/rpc/jsonrpc/server.go
+++ b/src/pkg/net/rpc/jsonrpc/server.go
@@ -12,6 +12,8 @@ import (
"sync"
)
+var errMissingParams = errors.New("jsonrpc: request body missing params")
+
type serverCodec struct {
dec *json.Decoder // for reading JSON values
enc *json.Encoder // for writing JSON values
@@ -50,12 +52,8 @@ type serverRequest struct {
func (r *serverRequest) reset() {
r.Method = ""
- if r.Params != nil {
- *r.Params = (*r.Params)[0:0]
- }
- if r.Id != nil {
- *r.Id = (*r.Id)[0:0]
- }
+ r.Params = nil
+ r.Id = nil
}
type serverResponse struct {
@@ -88,6 +86,9 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error {
if x == nil {
return nil
}
+ if c.req.Params == nil {
+ return errMissingParams
+ }
// JSON params is array value.
// RPC params is struct.
// Unmarshal into array containing struct for now.
diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go
index 1680e2f0d..e71b6fb1a 100644
--- a/src/pkg/net/rpc/server.go
+++ b/src/pkg/net/rpc/server.go
@@ -24,12 +24,13 @@
where T, T1 and T2 can be marshaled by encoding/gob.
These requirements apply even if a different codec is used.
- (In future, these requirements may soften for custom codecs.)
+ (In the future, these requirements may soften for custom codecs.)
The method's first argument represents the arguments provided by the caller; the
second argument represents the result parameters to be returned to the caller.
The method's return value, if non-nil, is passed back as a string that the client
- sees as if created by errors.New.
+ sees as if created by errors.New. If an error is returned, the reply parameter
+ will not be sent back to the client.
The server may handle requests on a single connection by calling ServeConn. More
typically it will create a network listener and call Accept or, for an HTTP
@@ -111,7 +112,7 @@
// Asynchronous call
quotient := new(Quotient)
- divCall := client.Go("Arith.Divide", args, &quotient, nil)
+ divCall := client.Go("Arith.Divide", args, quotient, nil)
replyCall := <-divCall.Done // will be equal to divCall
// check errors, print, etc.
@@ -181,7 +182,7 @@ type Response struct {
// Server represents an RPC Server.
type Server struct {
- mu sync.Mutex // protects the serviceMap
+ mu sync.RWMutex // protects the serviceMap
serviceMap map[string]*service
reqLock sync.Mutex // protects freeReq
freeReq *Request
@@ -218,15 +219,15 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
// - exported method
// - two arguments, both pointers to exported structs
// - one return value, of type error
-// It returns an error if the receiver is not an exported type or has no
-// suitable methods.
+// It returns an error if the receiver is not an exported type or has
+// no methods or unsuitable methods. It also logs the error using package log.
// The client accesses each method using a string of the form "Type.Method",
// where Type is the receiver's concrete type.
func (server *Server) Register(rcvr interface{}) error {
return server.register(rcvr, "", false)
}
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func (server *Server) RegisterName(name string, rcvr interface{}) error {
return server.register(rcvr, name, true)
@@ -260,8 +261,30 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
s.method = make(map[string]*methodType)
// Install the methods
- for m := 0; m < s.typ.NumMethod(); m++ {
- method := s.typ.Method(m)
+ s.method = suitableMethods(s.typ, true)
+
+ if len(s.method) == 0 {
+ str := ""
+ // To help the user, see if a pointer receiver would work.
+ method := suitableMethods(reflect.PtrTo(s.typ), false)
+ if len(method) != 0 {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
+ } else {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
+ }
+ log.Print(str)
+ return errors.New(str)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+// suitableMethods returns suitable Rpc methods of typ, it will report
+// error using log if reportErr is true.
+func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
+ methods := make(map[string]*methodType)
+ for m := 0; m < typ.NumMethod(); m++ {
+ method := typ.Method(m)
mtype := method.Type
mname := method.Name
// Method must be exported.
@@ -270,46 +293,51 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
- log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ if reportErr {
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ }
continue
}
// First arg need not be a pointer.
argType := mtype.In(1)
if !isExportedOrBuiltinType(argType) {
- log.Println(mname, "argument type not exported:", argType)
+ if reportErr {
+ log.Println(mname, "argument type not exported:", argType)
+ }
continue
}
// Second arg must be a pointer.
replyType := mtype.In(2)
if replyType.Kind() != reflect.Ptr {
- log.Println("method", mname, "reply type not a pointer:", replyType)
+ if reportErr {
+ log.Println("method", mname, "reply type not a pointer:", replyType)
+ }
continue
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
- log.Println("method", mname, "reply type not exported:", replyType)
+ if reportErr {
+ log.Println("method", mname, "reply type not exported:", replyType)
+ }
continue
}
// Method needs one out.
if mtype.NumOut() != 1 {
- log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ if reportErr {
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ }
continue
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
- log.Println("method", mname, "returns", returnType.String(), "not error")
+ if reportErr {
+ log.Println("method", mname, "returns", returnType.String(), "not error")
+ }
continue
}
- s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
}
-
- if len(s.method) == 0 {
- s := "rpc Register: type " + sname + " has no exported methods of suitable type"
- log.Print(s)
- return errors.New(s)
- }
- server.serviceMap[s.name] = s
- return nil
+ return methods
}
// A value sent as a placeholder for the server's response value when the server
@@ -538,9 +566,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
return
}
// Look up the request.
- server.mu.Lock()
+ server.mu.RLock()
service = server.serviceMap[serviceMethod[0]]
- server.mu.Unlock()
+ server.mu.RUnlock()
if service == nil {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
@@ -568,7 +596,7 @@ func (server *Server) Accept(lis net.Listener) {
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func RegisterName(name string, rcvr interface{}) error {
return DefaultServer.RegisterName(name, rcvr)
@@ -611,7 +639,7 @@ func ServeRequest(codec ServerCodec) error {
}
// Accept accepts connections on the listener and serves requests
-// to DefaultServer for each incoming connection.
+// to DefaultServer for each incoming connection.
// Accept blocks; the caller typically invokes it in a go statement.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go
index 62c7b1e60..8a1530623 100644
--- a/src/pkg/net/rpc/server_test.go
+++ b/src/pkg/net/rpc/server_test.go
@@ -349,6 +349,7 @@ func testServeRequest(t *testing.T, server *Server) {
type ReplyNotPointer int
type ArgNotPublic int
type ReplyNotPublic int
+type NeedsPtrType int
type local struct{}
func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
@@ -363,19 +364,29 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
return nil
}
+func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
+ return nil
+}
+
// Check that registration handles lots of bad methods and a type with no suitable methods.
func TestRegistrationError(t *testing.T) {
err := Register(new(ReplyNotPointer))
if err == nil {
- t.Errorf("expected error registering ReplyNotPointer")
+ t.Error("expected error registering ReplyNotPointer")
}
err = Register(new(ArgNotPublic))
if err == nil {
- t.Errorf("expected error registering ArgNotPublic")
+ t.Error("expected error registering ArgNotPublic")
}
err = Register(new(ReplyNotPublic))
if err == nil {
- t.Errorf("expected error registering ReplyNotPublic")
+ t.Error("expected error registering ReplyNotPublic")
+ }
+ err = Register(NeedsPtrType(0))
+ if err == nil {
+ t.Error("expected error registering NeedsPtrType")
+ } else if !strings.Contains(err.Error(), "pointer") {
+ t.Error("expected hint when registering NeedsPtrType")
}
}
@@ -434,7 +445,7 @@ func dialHTTP() (*Client, error) {
return DialHTTP("tcp", httpServerAddr)
}
-func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
+func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
once.Do(startServer)
client, err := dial()
if err != nil {
@@ -442,11 +453,7 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
}
args := &Args{7, 8}
reply := new(Reply)
- memstats := new(runtime.MemStats)
- runtime.ReadMemStats(memstats)
- mallocs := 0 - memstats.Mallocs
- const count = 100
- for i := 0; i < count; i++ {
+ return testing.AllocsPerRun(100, func() {
err := client.Call("Arith.Add", args, reply)
if err != nil {
t.Errorf("Add: expected no error but got string %q", err.Error())
@@ -454,18 +461,15 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
if reply.C != args.A+args.B {
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
- }
- runtime.ReadMemStats(memstats)
- mallocs += memstats.Mallocs
- return mallocs / count
+ })
}
func TestCountMallocs(t *testing.T) {
- fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t))
+ fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
}
func TestCountMallocsOverHTTP(t *testing.T) {
- fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t))
+ fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
}
type writeCrasher struct {
@@ -499,6 +503,44 @@ func TestClientWriteError(t *testing.T) {
w.done <- true
}
+func TestTCPClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ defer client.Close()
+
+ args := Args{17, 8}
+ var reply Reply
+ err = client.Call("Arith.Mul", args, &reply)
+ if err != nil {
+ t.Fatal("arith error:", err)
+ }
+ t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
+ if reply.C != args.A*args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
+ }
+}
+
+func TestErrorAfterClientClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ err = client.Close()
+ if err != nil {
+ t.Fatal("close error:", err)
+ }
+ err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
+ if err != ErrShutdown {
+ t.Errorf("Forever: expected ErrShutdown got %v", err)
+ }
+}
+
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
b.StopTimer()
once.Do(startServer)