diff options
Diffstat (limited to 'src/pkg/net/rpc')
-rw-r--r-- | src/pkg/net/rpc/client.go | 43 | ||||
-rw-r--r-- | src/pkg/net/rpc/jsonrpc/all_test.go | 61 | ||||
-rw-r--r-- | src/pkg/net/rpc/jsonrpc/server.go | 13 | ||||
-rw-r--r-- | src/pkg/net/rpc/server.go | 84 | ||||
-rw-r--r-- | src/pkg/net/rpc/server_test.go | 72 |
5 files changed, 193 insertions, 80 deletions
diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go index db2da8e44..4b0c9c3bb 100644 --- a/src/pkg/net/rpc/client.go +++ b/src/pkg/net/rpc/client.go @@ -71,7 +71,7 @@ func (client *Client) send(call *Call) { // Register this call. client.mutex.Lock() - if client.shutdown { + if client.shutdown || client.closing { call.Error = ErrShutdown client.mutex.Unlock() call.done() @@ -88,10 +88,13 @@ func (client *Client) send(call *Call) { err := client.codec.WriteRequest(&client.request, call.Args) if err != nil { client.mutex.Lock() + call = client.pending[seq] delete(client.pending, seq) client.mutex.Unlock() - call.Error = err - call.done() + if call != nil { + call.Error = err + call.done() + } } } @@ -102,9 +105,6 @@ func (client *Client) input() { response = Response{} err = client.codec.ReadResponseHeader(&response) if err != nil { - if err == io.EOF && !client.closing { - err = io.ErrUnexpectedEOF - } break } seq := response.Seq @@ -113,12 +113,18 @@ func (client *Client) input() { delete(client.pending, seq) client.mutex.Unlock() - if response.Error == "" { - err = client.codec.ReadResponseBody(call.Reply) + switch { + case call == nil: + // We've got no pending call. That usually means that + // WriteRequest partially failed, and call was already + // removed; response is a server telling us about an + // error reading request body. We should still attempt + // to read error body, but there's no one to give it to. + err = client.codec.ReadResponseBody(nil) if err != nil { - call.Error = errors.New("reading body " + err.Error()) + err = errors.New("reading error body: " + err.Error()) } - } else { + case response.Error != "": // We've got an error response. Give this to the request; // any subsequent requests will get the ReadResponseBody // error if there is one. @@ -127,14 +133,27 @@ func (client *Client) input() { if err != nil { err = errors.New("reading error body: " + err.Error()) } + call.done() + default: + err = client.codec.ReadResponseBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() } - call.done() } // Terminate pending calls. client.sending.Lock() client.mutex.Lock() client.shutdown = true closing := client.closing + if err == io.EOF { + if closing { + err = ErrShutdown + } else { + err = io.ErrUnexpectedEOF + } + } for _, call := range client.pending { call.Error = err call.done() @@ -213,7 +232,7 @@ func DialHTTP(network, address string) (*Client, error) { return DialHTTPPath(network, address, DefaultRPCPath) } -// DialHTTPPath connects to an HTTP RPC server +// DialHTTPPath connects to an HTTP RPC server // at the specified network address and path. func DialHTTPPath(network, address, path string) (*Client, error) { var err error diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go index e6c7441f0..3c7c4d48f 100644 --- a/src/pkg/net/rpc/jsonrpc/all_test.go +++ b/src/pkg/net/rpc/jsonrpc/all_test.go @@ -24,6 +24,12 @@ type Reply struct { type Arith int +type ArithAddResp struct { + Id interface{} `json:"id"` + Result Reply `json:"result"` + Error interface{} `json:"error"` +} + func (t *Arith) Add(args *Args, reply *Reply) error { reply.C = args.A + args.B return nil @@ -50,13 +56,39 @@ func init() { rpc.Register(new(Arith)) } -func TestServer(t *testing.T) { - type addResp struct { - Id interface{} `json:"id"` - Result Reply `json:"result"` - Error interface{} `json:"error"` +func TestServerNoParams(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`) + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after no params: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestServerEmptyMessage(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, "{}") + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} +func TestServer(t *testing.T) { cli, srv := net.Pipe() defer cli.Close() go ServeConn(srv) @@ -65,7 +97,7 @@ func TestServer(t *testing.T) { // Send hand-coded requests to server, parse responses. for i := 0; i < 10; i++ { fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1) - var resp addResp + var resp ArithAddResp err := dec.Decode(&resp) if err != nil { t.Fatalf("Decode: %s", err) @@ -80,15 +112,6 @@ func TestServer(t *testing.T) { t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) } } - - fmt.Fprintf(cli, "{}\n") - var resp addResp - if err := dec.Decode(&resp); err != nil { - t.Fatalf("Decode after empty: %s", err) - } - if resp.Error == nil { - t.Fatalf("Expected error, got nil") - } } func TestClient(t *testing.T) { @@ -108,7 +131,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", err.Error()) } if reply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B) } args = &Args{7, 8} @@ -118,7 +141,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", err.Error()) } if reply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B) } // Out of order. @@ -133,7 +156,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) } if addReply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B) } mulCall = <-mulCall.Done @@ -141,7 +164,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) } if mulReply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B) } // Error test diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go index 4c54553a7..5bc05fd0a 100644 --- a/src/pkg/net/rpc/jsonrpc/server.go +++ b/src/pkg/net/rpc/jsonrpc/server.go @@ -12,6 +12,8 @@ import ( "sync" ) +var errMissingParams = errors.New("jsonrpc: request body missing params") + type serverCodec struct { dec *json.Decoder // for reading JSON values enc *json.Encoder // for writing JSON values @@ -50,12 +52,8 @@ type serverRequest struct { func (r *serverRequest) reset() { r.Method = "" - if r.Params != nil { - *r.Params = (*r.Params)[0:0] - } - if r.Id != nil { - *r.Id = (*r.Id)[0:0] - } + r.Params = nil + r.Id = nil } type serverResponse struct { @@ -88,6 +86,9 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error { if x == nil { return nil } + if c.req.Params == nil { + return errMissingParams + } // JSON params is array value. // RPC params is struct. // Unmarshal into array containing struct for now. diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go index 1680e2f0d..e71b6fb1a 100644 --- a/src/pkg/net/rpc/server.go +++ b/src/pkg/net/rpc/server.go @@ -24,12 +24,13 @@ where T, T1 and T2 can be marshaled by encoding/gob. These requirements apply even if a different codec is used. - (In future, these requirements may soften for custom codecs.) + (In the future, these requirements may soften for custom codecs.) The method's first argument represents the arguments provided by the caller; the second argument represents the result parameters to be returned to the caller. The method's return value, if non-nil, is passed back as a string that the client - sees as if created by errors.New. + sees as if created by errors.New. If an error is returned, the reply parameter + will not be sent back to the client. The server may handle requests on a single connection by calling ServeConn. More typically it will create a network listener and call Accept or, for an HTTP @@ -111,7 +112,7 @@ // Asynchronous call quotient := new(Quotient) - divCall := client.Go("Arith.Divide", args, "ient, nil) + divCall := client.Go("Arith.Divide", args, quotient, nil) replyCall := <-divCall.Done // will be equal to divCall // check errors, print, etc. @@ -181,7 +182,7 @@ type Response struct { // Server represents an RPC Server. type Server struct { - mu sync.Mutex // protects the serviceMap + mu sync.RWMutex // protects the serviceMap serviceMap map[string]*service reqLock sync.Mutex // protects freeReq freeReq *Request @@ -218,15 +219,15 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // - exported method // - two arguments, both pointers to exported structs // - one return value, of type error -// It returns an error if the receiver is not an exported type or has no -// suitable methods. +// It returns an error if the receiver is not an exported type or has +// no methods or unsuitable methods. It also logs the error using package log. // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. func (server *Server) Register(rcvr interface{}) error { return server.register(rcvr, "", false) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. func (server *Server) RegisterName(name string, rcvr interface{}) error { return server.register(rcvr, name, true) @@ -260,8 +261,30 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro s.method = make(map[string]*methodType) // Install the methods - for m := 0; m < s.typ.NumMethod(); m++ { - method := s.typ.Method(m) + s.method = suitableMethods(s.typ, true) + + if len(s.method) == 0 { + str := "" + // To help the user, see if a pointer receiver would work. + method := suitableMethods(reflect.PtrTo(s.typ), false) + if len(method) != 0 { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" + } else { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type" + } + log.Print(str) + return errors.New(str) + } + server.serviceMap[s.name] = s + return nil +} + +// suitableMethods returns suitable Rpc methods of typ, it will report +// error using log if reportErr is true. +func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { + methods := make(map[string]*methodType) + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) mtype := method.Type mname := method.Name // Method must be exported. @@ -270,46 +293,51 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro } // Method needs three ins: receiver, *args, *reply. if mtype.NumIn() != 3 { - log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + if reportErr { + log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + } continue } // First arg need not be a pointer. argType := mtype.In(1) if !isExportedOrBuiltinType(argType) { - log.Println(mname, "argument type not exported:", argType) + if reportErr { + log.Println(mname, "argument type not exported:", argType) + } continue } // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println("method", mname, "reply type not a pointer:", replyType) + if reportErr { + log.Println("method", mname, "reply type not a pointer:", replyType) + } continue } // Reply type must be exported. if !isExportedOrBuiltinType(replyType) { - log.Println("method", mname, "reply type not exported:", replyType) + if reportErr { + log.Println("method", mname, "reply type not exported:", replyType) + } continue } // Method needs one out. if mtype.NumOut() != 1 { - log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + if reportErr { + log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + } continue } // The return type of the method must be error. if returnType := mtype.Out(0); returnType != typeOfError { - log.Println("method", mname, "returns", returnType.String(), "not error") + if reportErr { + log.Println("method", mname, "returns", returnType.String(), "not error") + } continue } - s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} + methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} } - - if len(s.method) == 0 { - s := "rpc Register: type " + sname + " has no exported methods of suitable type" - log.Print(s) - return errors.New(s) - } - server.serviceMap[s.name] = s - return nil + return methods } // A value sent as a placeholder for the server's response value when the server @@ -538,9 +566,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt return } // Look up the request. - server.mu.Lock() + server.mu.RLock() service = server.serviceMap[serviceMethod[0]] - server.mu.Unlock() + server.mu.RUnlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return @@ -568,7 +596,7 @@ func (server *Server) Accept(lis net.Listener) { // Register publishes the receiver's methods in the DefaultServer. func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. func RegisterName(name string, rcvr interface{}) error { return DefaultServer.RegisterName(name, rcvr) @@ -611,7 +639,7 @@ func ServeRequest(codec ServerCodec) error { } // Accept accepts connections on the listener and serves requests -// to DefaultServer for each incoming connection. +// to DefaultServer for each incoming connection. // Accept blocks; the caller typically invokes it in a go statement. func Accept(lis net.Listener) { DefaultServer.Accept(lis) } diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go index 62c7b1e60..8a1530623 100644 --- a/src/pkg/net/rpc/server_test.go +++ b/src/pkg/net/rpc/server_test.go @@ -349,6 +349,7 @@ func testServeRequest(t *testing.T, server *Server) { type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int +type NeedsPtrType int type local struct{} func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { @@ -363,19 +364,29 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { return nil } +func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { + return nil +} + // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { err := Register(new(ReplyNotPointer)) if err == nil { - t.Errorf("expected error registering ReplyNotPointer") + t.Error("expected error registering ReplyNotPointer") } err = Register(new(ArgNotPublic)) if err == nil { - t.Errorf("expected error registering ArgNotPublic") + t.Error("expected error registering ArgNotPublic") } err = Register(new(ReplyNotPublic)) if err == nil { - t.Errorf("expected error registering ReplyNotPublic") + t.Error("expected error registering ReplyNotPublic") + } + err = Register(NeedsPtrType(0)) + if err == nil { + t.Error("expected error registering NeedsPtrType") + } else if !strings.Contains(err.Error(), "pointer") { + t.Error("expected hint when registering NeedsPtrType") } } @@ -434,7 +445,7 @@ func dialHTTP() (*Client, error) { return DialHTTP("tcp", httpServerAddr) } -func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { +func countMallocs(dial func() (*Client, error), t *testing.T) float64 { once.Do(startServer) client, err := dial() if err != nil { @@ -442,11 +453,7 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { } args := &Args{7, 8} reply := new(Reply) - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - const count = 100 - for i := 0; i < count; i++ { + return testing.AllocsPerRun(100, func() { err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.Error()) @@ -454,18 +461,15 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { if reply.C != args.A+args.B { t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - return mallocs / count + }) } func TestCountMallocs(t *testing.T) { - fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) + fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) } func TestCountMallocsOverHTTP(t *testing.T) { - fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) + fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) } type writeCrasher struct { @@ -499,6 +503,44 @@ func TestClientWriteError(t *testing.T) { w.done <- true } +func TestTCPClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + defer client.Close() + + args := Args{17, 8} + var reply Reply + err = client.Call("Arith.Mul", args, &reply) + if err != nil { + t.Fatal("arith error:", err) + } + t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) + if reply.C != args.A*args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) + } +} + +func TestErrorAfterClientClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + err = client.Close() + if err != nil { + t.Fatal("close error:", err) + } + err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) + if err != ErrShutdown { + t.Errorf("Forever: expected ErrShutdown got %v", err) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { b.StopTimer() once.Do(startServer) |