diff options
Diffstat (limited to 'src/pkg/rpc')
-rw-r--r-- | src/pkg/rpc/client.go | 58 | ||||
-rw-r--r-- | src/pkg/rpc/debug.go | 2 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/client.go | 3 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/server.go | 3 | ||||
-rw-r--r-- | src/pkg/rpc/server.go | 90 | ||||
-rw-r--r-- | src/pkg/rpc/server_test.go | 103 |
6 files changed, 139 insertions, 120 deletions
diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go index 6f028c10d..6de6d1325 100644 --- a/src/pkg/rpc/client.go +++ b/src/pkg/rpc/client.go @@ -15,6 +15,16 @@ import ( "sync" ) +// ServerError represents an error that has been returned from +// the remote side of the RPC connection. +type ServerError string + +func (e ServerError) String() string { + return string(e) +} + +const ErrShutdown = os.ErrorString("connection is shut down") + // Call represents an active RPC. type Call struct { ServiceMethod string // The name of the service and method to call. @@ -30,12 +40,12 @@ type Call struct { // with a single Client. type Client struct { mutex sync.Mutex // protects pending, seq - shutdown os.Error // non-nil if the client is shut down sending sync.Mutex seq uint64 codec ClientCodec pending map[uint64]*Call closing bool + shutdown bool } // A ClientCodec implements writing of RPC requests and @@ -43,7 +53,9 @@ type Client struct { // The client calls WriteRequest to write a request to the connection // and calls ReadResponseHeader and ReadResponseBody in pairs // to read responses. The client calls Close when finished with the -// connection. +// connection. ReadResponseBody may be called with a nil +// argument to force the body of the response to be read and then +// discarded. type ClientCodec interface { WriteRequest(*Request, interface{}) os.Error ReadResponseHeader(*Response) os.Error @@ -55,8 +67,8 @@ type ClientCodec interface { func (client *Client) send(c *Call) { // Register this call. client.mutex.Lock() - if client.shutdown != nil { - c.Error = client.shutdown + if client.shutdown { + c.Error = ErrShutdown client.mutex.Unlock() c.done() return @@ -93,20 +105,27 @@ func (client *Client) input() { c := client.pending[seq] client.pending[seq] = c, false client.mutex.Unlock() - err = client.codec.ReadResponseBody(c.Reply) - if response.Error != "" { - c.Error = os.ErrorString(response.Error) - } else if err != nil { - c.Error = err + + if response.Error == "" { + err = client.codec.ReadResponseBody(c.Reply) + if err != nil { + c.Error = os.ErrorString("reading body " + err.String()) + } } else { - // Empty strings should turn into nil os.Errors - c.Error = nil + // We've got an error response. Give this to the request; + // any subsequent requests will get the ReadResponseBody + // error if there is one. + c.Error = ServerError(response.Error) + err = client.codec.ReadResponseBody(nil) + if err != nil { + err = os.ErrorString("reading error body: " + err.String()) + } } c.done() } // Terminate pending calls. client.mutex.Lock() - client.shutdown = err + client.shutdown = true for _, call := range client.pending { call.Error = err call.done() @@ -209,10 +228,11 @@ func Dial(network, address string) (*Client, os.Error) { } func (client *Client) Close() os.Error { - if client.shutdown != nil || client.closing { - return os.ErrorString("rpc: already closed") - } client.mutex.Lock() + if client.shutdown || client.closing { + client.mutex.Unlock() + return ErrShutdown + } client.closing = true client.mutex.Unlock() return client.codec.Close() @@ -239,8 +259,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface } } c.Done = done - if client.shutdown != nil { - c.Error = client.shutdown + if client.shutdown { + c.Error = ErrShutdown c.done() return c } @@ -250,8 +270,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface // Call invokes the named function, waits for it to complete, and returns its error status. func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error { - if client.shutdown != nil { - return client.shutdown + if client.shutdown { + return ErrShutdown } call := <-client.Go(serviceMethod, args, reply, nil).Done return call.Error diff --git a/src/pkg/rpc/debug.go b/src/pkg/rpc/debug.go index 44b32e04b..32dc8a18b 100644 --- a/src/pkg/rpc/debug.go +++ b/src/pkg/rpc/debug.go @@ -83,7 +83,7 @@ func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { } server.Unlock() sort.Sort(services) - err := debug.Execute(services, w) + err := debug.Execute(w, services) if err != nil { fmt.Fprintln(w, "rpc: error executing template:", err.String()) } diff --git a/src/pkg/rpc/jsonrpc/client.go b/src/pkg/rpc/jsonrpc/client.go index dcaa69f9d..5b806bd6e 100644 --- a/src/pkg/rpc/jsonrpc/client.go +++ b/src/pkg/rpc/jsonrpc/client.go @@ -98,6 +98,9 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error { } func (c *clientCodec) ReadResponseBody(x interface{}) os.Error { + if x == nil { + return nil + } return json.Unmarshal(*c.resp.Result, x) } diff --git a/src/pkg/rpc/jsonrpc/server.go b/src/pkg/rpc/jsonrpc/server.go index bf53bda8d..9c6b8b40d 100644 --- a/src/pkg/rpc/jsonrpc/server.go +++ b/src/pkg/rpc/jsonrpc/server.go @@ -85,6 +85,9 @@ func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error { } func (c *serverCodec) ReadRequestBody(x interface{}) os.Error { + if x == nil { + return nil + } // JSON params is array value. // RPC params is struct. // Unmarshal into array containing struct for now. diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index 91e9cd5c8..9dcda4148 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -299,10 +299,10 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E // A value sent as a placeholder for the response when the server receives an invalid request. type InvalidRequest struct { - marker int + Marker int } -var invalidRequest = InvalidRequest{1} +var invalidRequest = InvalidRequest{} func _new(t *reflect.PtrType) *reflect.PtrValue { v := reflect.MakeZero(t).(*reflect.PtrValue) @@ -316,6 +316,7 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se resp.ServiceMethod = req.ServiceMethod if errmsg != "" { resp.Error = errmsg + reply = invalidRequest } resp.Seq = req.Seq sending.Lock() @@ -389,54 +390,74 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) for { - // Grab the request header. - req := new(Request) - err := codec.ReadRequestHeader(req) + req, service, mtype, err := server.readRequest(codec) if err != nil { + if err != os.EOF { + log.Println("rpc:", err) + } if err == os.EOF || err == io.ErrUnexpectedEOF { - if err == io.ErrUnexpectedEOF { - log.Println("rpc:", err) - } break } - s := "rpc: server cannot decode request: " + err.String() - sendResponse(sending, req, invalidRequest, codec, s) - break - } - serviceMethod := strings.Split(req.ServiceMethod, ".", -1) - if len(serviceMethod) != 2 { - s := "rpc: service/method request ill-formed: " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, codec, s) - continue - } - // Look up the request. - server.Lock() - service, ok := server.serviceMap[serviceMethod[0]] - server.Unlock() - if !ok { - s := "rpc: can't find service " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, codec, s) - continue - } - mtype, ok := service.method[serviceMethod[1]] - if !ok { - s := "rpc: can't find method " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, codec, s) + // discard body + codec.ReadRequestBody(nil) + + // send a response if we actually managed to read a header. + if req != nil { + sendResponse(sending, req, invalidRequest, codec, err.String()) + } continue } + // Decode the argument value. argv := _new(mtype.ArgType) replyv := _new(mtype.ReplyType) err = codec.ReadRequestBody(argv.Interface()) if err != nil { - log.Println("rpc: tearing down", serviceMethod[0], "connection:", err) + if err == os.EOF || err == io.ErrUnexpectedEOF { + if err == io.ErrUnexpectedEOF { + log.Println("rpc:", err) + } + break + } sendResponse(sending, req, replyv.Interface(), codec, err.String()) - break + continue } go service.call(sending, mtype, req, argv, replyv, codec) } codec.Close() } +func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) { + // Grab the request header. + req = new(Request) + err = codec.ReadRequestHeader(req) + if err != nil { + req = nil + if err == os.EOF || err == io.ErrUnexpectedEOF { + return + } + err = os.ErrorString("rpc: server cannot decode request: " + err.String()) + return + } + + serviceMethod := strings.Split(req.ServiceMethod, ".", -1) + if len(serviceMethod) != 2 { + err = os.ErrorString("rpc: service/method request ill-formed: " + req.ServiceMethod) + return + } + // Look up the request. + server.Lock() + service = server.serviceMap[serviceMethod[0]] + server.Unlock() + if service == nil { + err = os.ErrorString("rpc: can't find service " + req.ServiceMethod) + return + } + mtype = service.method[serviceMethod[1]] + if mtype == nil { + err = os.ErrorString("rpc: can't find method " + req.ServiceMethod) + } + return +} // Accept accepts connections on the listener and serves requests // for each incoming connection. Accept blocks; the caller typically @@ -465,7 +486,8 @@ func RegisterName(name string, rcvr interface{}) os.Error { // The server calls ReadRequestHeader and ReadRequestBody in pairs // to read requests from the connection, and it calls WriteResponse to // write a response back. The server calls Close when finished with the -// connection. +// connection. ReadRequestBody may be called with a nil +// argument to force the body of the request to be read and discarded. type ServerCodec interface { ReadRequestHeader(*Request) os.Error ReadRequestBody(interface{}) os.Error diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index 1f080faa5..05aaebceb 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -134,14 +134,25 @@ func testRPC(t *testing.T, addr string) { t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) } - args = &Args{7, 8} + // Nonexistent method + args = &Args{7, 0} reply = new(Reply) - err = client.Call("Arith.Mul", args, reply) - if err != nil { - t.Errorf("Mul: expected no error but got string %q", err.String()) + err = client.Call("Arith.BadOperation", args, reply) + // expect an error + if err == nil { + t.Error("BadOperation: expected error") + } else if !strings.HasPrefix(err.String(), "rpc: can't find method ") { + t.Errorf("BadOperation: expected can't find method error; got %q", err) } - if reply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + + // Unknown service + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Unknown", args, reply) + if err == nil { + t.Error("expected error calling unknown service") + } else if strings.Index(err.String(), "method") < 0 { + t.Error("expected error about method; got", err) } // Out of order. @@ -178,6 +189,15 @@ func testRPC(t *testing.T, addr string) { t.Error("Div: expected divide by zero error; got", err) } + // Bad type. + reply = new(Reply) + err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use + if err == nil { + t.Error("expected error calling Arith.Add with wrong arg type") + } else if strings.Index(err.String(), "type") < 0 { + t.Error("expected error about type; got", err) + } + // Non-struct argument const Val = 12345 str := fmt.Sprint(Val) @@ -200,9 +220,19 @@ func testRPC(t *testing.T, addr string) { if str != expect { t.Errorf("String: expected %s got %s", expect, str) } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.String()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } } -func TestHTTPRPC(t *testing.T) { +func TestHTTP(t *testing.T) { once.Do(startServer) testHTTPRPC(t, "") newOnce.Do(startNewServer) @@ -233,65 +263,6 @@ func testHTTPRPC(t *testing.T, path string) { } } -func TestCheckUnknownService(t *testing.T) { - once.Do(startServer) - - conn, err := net.Dial("tcp", "", serverAddr) - if err != nil { - t.Fatal("dialing:", err) - } - - client := NewClient(conn) - - args := &Args{7, 8} - reply := new(Reply) - err = client.Call("Unknown.Add", args, reply) - if err == nil { - t.Error("expected error calling unknown service") - } else if strings.Index(err.String(), "service") < 0 { - t.Error("expected error about service; got", err) - } -} - -func TestCheckUnknownMethod(t *testing.T) { - once.Do(startServer) - - conn, err := net.Dial("tcp", "", serverAddr) - if err != nil { - t.Fatal("dialing:", err) - } - - client := NewClient(conn) - - args := &Args{7, 8} - reply := new(Reply) - err = client.Call("Arith.Unknown", args, reply) - if err == nil { - t.Error("expected error calling unknown service") - } else if strings.Index(err.String(), "method") < 0 { - t.Error("expected error about method; got", err) - } -} - -func TestCheckBadType(t *testing.T) { - once.Do(startServer) - - conn, err := net.Dial("tcp", "", serverAddr) - if err != nil { - t.Fatal("dialing:", err) - } - - client := NewClient(conn) - - reply := new(Reply) - err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use - if err == nil { - t.Error("expected error calling Arith.Add with wrong arg type") - } else if strings.Index(err.String(), "type") < 0 { - t.Error("expected error about type; got", err) - } -} - type ArgNotPointer int type ReplyNotPointer int type ArgNotPublic int |