diff options
Diffstat (limited to 'src/pkg/rpc')
-rw-r--r-- | src/pkg/rpc/Makefile | 2 | ||||
-rw-r--r-- | src/pkg/rpc/client.go | 21 | ||||
-rw-r--r-- | src/pkg/rpc/debug.go | 36 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/Makefile | 2 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/all_test.go | 13 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/client.go | 17 | ||||
-rw-r--r-- | src/pkg/rpc/jsonrpc/server.go | 14 | ||||
-rw-r--r-- | src/pkg/rpc/server.go | 191 | ||||
-rw-r--r-- | src/pkg/rpc/server_test.go | 126 |
9 files changed, 308 insertions, 114 deletions
diff --git a/src/pkg/rpc/Makefile b/src/pkg/rpc/Makefile index 4757b3aae..191b10d05 100644 --- a/src/pkg/rpc/Makefile +++ b/src/pkg/rpc/Makefile @@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. -include ../../Make.$(GOARCH) +include ../../Make.inc TARG=rpc GOFILES=\ diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go index d742d099f..601c49715 100644 --- a/src/pkg/rpc/client.go +++ b/src/pkg/rpc/client.go @@ -69,12 +69,12 @@ func (client *Client) send(c *Call) { // Encode and send the request. request := new(Request) client.sending.Lock() + defer client.sending.Unlock() request.Seq = c.seq request.ServiceMethod = c.ServiceMethod if err := client.codec.WriteRequest(request, c.Args); err != nil { panic("rpc: client encode error: " + err.String()) } - client.sending.Unlock() } func (client *Client) input() { @@ -94,10 +94,12 @@ func (client *Client) input() { client.pending[seq] = c, false client.mutex.Unlock() err = client.codec.ReadResponseBody(c.Reply) - // Empty strings should turn into nil os.Errors if response.Error != "" { c.Error = os.ErrorString(response.Error) + } else if err != nil { + c.Error = err } else { + // Empty strings should turn into nil os.Errors c.Error = nil } // We don't want to block here. It is the caller's responsibility to make @@ -113,7 +115,7 @@ func (client *Client) input() { } client.mutex.Unlock() if err != os.EOF || !client.closing { - log.Stderr("rpc: client protocol error:", err) + log.Println("rpc: client protocol error:", err) } } @@ -160,14 +162,21 @@ func (c *gobClientCodec) Close() os.Error { } -// DialHTTP connects to an HTTP RPC server at the specified network address. +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. func DialHTTP(network, address string) (*Client, os.Error) { + return DialHTTPPath(network, address, DefaultRPCPath) +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string) (*Client, os.Error) { var err os.Error conn, err := net.Dial(network, "", address) if err != nil { return nil, err } - io.WriteString(conn, "CONNECT "+rpcPath+" HTTP/1.0\n\n") + io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n") // Require successful HTTP response // before switching to RPC protocol. @@ -218,7 +227,7 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface // RPCs that will be using that channel. If the channel // is totally unbuffered, it's best not to run at all. if cap(done) == 0 { - log.Crash("rpc: done channel is unbuffered") + log.Panic("rpc: done channel is unbuffered") } } c.Done = done diff --git a/src/pkg/rpc/debug.go b/src/pkg/rpc/debug.go index 638584f49..44b32e04b 100644 --- a/src/pkg/rpc/debug.go +++ b/src/pkg/rpc/debug.go @@ -21,14 +21,14 @@ const debugText = `<html> <title>Services</title> {.repeated section @} <hr> - Service {name} + Service {Name} <hr> <table> <th align=center>Method</th><th align=center>Calls</th> - {.repeated section meth} + {.repeated section Method} <tr> - <td align=left font=fixed>{name}({m.argType}, {m.replyType}) os.Error</td> - <td align=center>{m.numCalls}</td> + <td align=left font=fixed>{Name}({Type.ArgType}, {Type.ReplyType}) os.Error</td> + <td align=center>{Type.NumCalls}</td> </tr> {.end} </table> @@ -39,30 +39,34 @@ const debugText = `<html> var debug = template.MustParse(debugText, nil) type debugMethod struct { - m *methodType - name string + Type *methodType + Name string } type methodArray []debugMethod type debugService struct { - s *service - name string - meth methodArray + Service *service + Name string + Method methodArray } type serviceArray []debugService func (s serviceArray) Len() int { return len(s) } -func (s serviceArray) Less(i, j int) bool { return s[i].name < s[j].name } +func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name } func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (m methodArray) Len() int { return len(m) } -func (m methodArray) Less(i, j int) bool { return m[i].name < m[j].name } +func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name } func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } +type debugHTTP struct { + *Server +} + // Runs at /debug/rpc -func debugHTTP(c *http.Conn, req *http.Request) { +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Build a sorted version of the data. var services = make(serviceArray, len(server.serviceMap)) i := 0 @@ -71,16 +75,16 @@ func debugHTTP(c *http.Conn, req *http.Request) { services[i] = debugService{service, sname, make(methodArray, len(service.method))} j := 0 for mname, method := range service.method { - services[i].meth[j] = debugMethod{method, mname} + services[i].Method[j] = debugMethod{method, mname} j++ } - sort.Sort(services[i].meth) + sort.Sort(services[i].Method) i++ } server.Unlock() sort.Sort(services) - err := debug.Execute(services, c) + err := debug.Execute(services, w) if err != nil { - fmt.Fprintln(c, "rpc: error executing template:", err.String()) + fmt.Fprintln(w, "rpc: error executing template:", err.String()) } } diff --git a/src/pkg/rpc/jsonrpc/Makefile b/src/pkg/rpc/jsonrpc/Makefile index 1a4fd2e92..b9a1ac2f7 100644 --- a/src/pkg/rpc/jsonrpc/Makefile +++ b/src/pkg/rpc/jsonrpc/Makefile @@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. -include ../../../Make.$(GOARCH) +include ../../../Make.inc TARG=rpc/jsonrpc GOFILES=\ diff --git a/src/pkg/rpc/jsonrpc/all_test.go b/src/pkg/rpc/jsonrpc/all_test.go index e94c594da..764ee7ff3 100644 --- a/src/pkg/rpc/jsonrpc/all_test.go +++ b/src/pkg/rpc/jsonrpc/all_test.go @@ -53,7 +53,7 @@ func TestServer(t *testing.T) { type addResp struct { Id interface{} "id" Result Reply "result" - Error string "error" + Error interface{} "error" } cli, srv := net.Pipe() @@ -69,7 +69,7 @@ func TestServer(t *testing.T) { if err != nil { t.Fatalf("Decode: %s", err) } - if resp.Error != "" { + if resp.Error != nil { t.Fatalf("resp.Error: %s", resp.Error) } if resp.Id.(string) != string(i) { @@ -79,6 +79,15 @@ func TestServer(t *testing.T) { t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) } } + + fmt.Fprintf(cli, "{}\n") + var resp addResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } } func TestClient(t *testing.T) { diff --git a/src/pkg/rpc/jsonrpc/client.go b/src/pkg/rpc/jsonrpc/client.go index ed2b4ed37..dcaa69f9d 100644 --- a/src/pkg/rpc/jsonrpc/client.go +++ b/src/pkg/rpc/jsonrpc/client.go @@ -7,6 +7,7 @@ package jsonrpc import ( + "fmt" "io" "json" "net" @@ -61,13 +62,13 @@ func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) os.Error { type clientResponse struct { Id uint64 "id" Result *json.RawMessage "result" - Error string "error" + Error interface{} "error" } func (r *clientResponse) reset() { r.Id = 0 r.Result = nil - r.Error = "" + r.Error = nil } func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error { @@ -81,8 +82,18 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error { c.pending[c.resp.Id] = "", false c.mutex.Unlock() + r.Error = "" r.Seq = c.resp.Id - r.Error = c.resp.Error + if c.resp.Error != nil { + x, ok := c.resp.Error.(string) + if !ok { + return fmt.Errorf("invalid error %v", c.resp.Error) + } + if x == "" { + x = "unspecified error" + } + r.Error = x + } return nil } diff --git a/src/pkg/rpc/jsonrpc/server.go b/src/pkg/rpc/jsonrpc/server.go index 9f3472a39..bf53bda8d 100644 --- a/src/pkg/rpc/jsonrpc/server.go +++ b/src/pkg/rpc/jsonrpc/server.go @@ -61,7 +61,7 @@ func (r *serverRequest) reset() { type serverResponse struct { Id *json.RawMessage "id" Result interface{} "result" - Error string "error" + Error interface{} "error" } func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error { @@ -94,6 +94,8 @@ func (c *serverCodec) ReadRequestBody(x interface{}) os.Error { return json.Unmarshal(*c.req.Params, ¶ms) } +var null = json.RawMessage([]byte("null")) + func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error { var resp serverResponse c.mutex.Lock() @@ -105,9 +107,17 @@ func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error { c.pending[r.Seq] = nil, false c.mutex.Unlock() + if b == nil { + // Invalid request so no id. Use JSON null. + b = &null + } resp.Id = b resp.Result = x - resp.Error = r.Error + if r.Error == "" { + resp.Error = nil + } else { + resp.Error = r.Error + } return c.enc.Encode(resp) } diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index d14f6ded2..5c50bcc3a 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -3,9 +3,9 @@ // license that can be found in the LICENSE file. /* - The rpc package provides access to the public methods of an object across a + The rpc package provides access to the exported methods of an object across a network or other I/O connection. A server registers an object, making it visible - as a service with the name of the type of the object. After registration, public + as a service with the name of the type of the object. After registration, exported methods of the object will be accessible remotely. A server may register multiple objects (services) of different types but it is an error to register multiple objects of the same type. @@ -13,8 +13,8 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are publicly visible, that is, begin with an upper case letter. - - the method has two arguments, both pointers to publicly visible types. + - the method receiver and name are exported, that is, begin with an upper case letter. + - the method has two arguments, both pointers to exported types. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -123,6 +123,12 @@ import ( "utf8" ) +const ( + // Defaults used by HandleHTTP + DefaultRPCPath = "/_goRPC_" + DefaultDebugPath = "/debug/rpc" +) + // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error @@ -131,8 +137,8 @@ var typeOfOsError = reflect.Typeof(unusedError).(*reflect.PtrType).Elem() type methodType struct { sync.Mutex // protects counters method reflect.Method - argType *reflect.PtrType - replyType *reflect.PtrType + ArgType *reflect.PtrType + ReplyType *reflect.PtrType numCalls uint } @@ -166,23 +172,46 @@ type ClientInfo struct { RemoteAddr string } -type serverType struct { +// Server represents an RPC Server. +type Server struct { sync.Mutex // protects the serviceMap serviceMap map[string]*service } -// This variable is a global whose "public" methods are really private methods -// called from the global functions of this package: rpc.Register, rpc.ServeConn, etc. -// For example, rpc.Register() calls server.add(). -var server = &serverType{serviceMap: make(map[string]*service)} +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{serviceMap: make(map[string]*service)} +} -// Is this a publicly visible - upper case - name? -func isPublic(name string) bool { +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// Is this an exported - upper case - name? +func isExported(name string) bool { rune, _ := utf8.DecodeRuneInString(name) return unicode.IsUpper(rune) } -func (server *serverType) register(rcvr interface{}) os.Error { +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method +// - two arguments, both pointers to exported structs +// - one return value, of type os.Error +// It returns an error if the receiver is not an exported type or has no +// suitable methods. +// The client accesses each method using a string of the form "Type.Method", +// where Type is the receiver's concrete type. +func (server *Server) Register(rcvr interface{}) os.Error { + return server.register(rcvr, "", false) +} + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func (server *Server) RegisterName(name string, rcvr interface{}) os.Error { + return server.register(rcvr, name, true) +} + +func (server *Server) register(rcvr interface{}, name string, useName bool) os.Error { server.Lock() defer server.Unlock() if server.serviceMap == nil { @@ -192,12 +221,15 @@ func (server *serverType) register(rcvr interface{}) os.Error { s.typ = reflect.Typeof(rcvr) s.rcvr = reflect.NewValue(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() + if useName { + sname = name + } if sname == "" { log.Exit("rpc: no service name for type", s.typ.String()) } - if s.typ.PkgPath() != "" && !isPublic(sname) { - s := "rpc Register: type " + sname + " is not public" - log.Stderr(s) + if s.typ.PkgPath() != "" && !isExported(sname) && !useName { + s := "rpc Register: type " + sname + " is not exported" + log.Print(s) return os.ErrorString(s) } if _, present := server.serviceMap[sname]; present { @@ -211,54 +243,54 @@ func (server *serverType) register(rcvr interface{}) os.Error { method := s.typ.Method(m) mtype := method.Type mname := method.Name - if mtype.PkgPath() != "" && !isPublic(mname) { + if mtype.PkgPath() != "" || !isExported(mname) { continue } // Method needs three ins: receiver, *args, *reply. if mtype.NumIn() != 3 { - log.Stderr("method", mname, "has wrong number of ins:", mtype.NumIn()) + log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } argType, ok := mtype.In(1).(*reflect.PtrType) if !ok { - log.Stderr(mname, "arg type not a pointer:", mtype.In(1)) + log.Println(mname, "arg type not a pointer:", mtype.In(1)) continue } replyType, ok := mtype.In(2).(*reflect.PtrType) if !ok { - log.Stderr(mname, "reply type not a pointer:", mtype.In(2)) + log.Println(mname, "reply type not a pointer:", mtype.In(2)) continue } - if argType.Elem().PkgPath() != "" && !isPublic(argType.Elem().Name()) { - log.Stderr(mname, "argument type not public:", argType) + if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { + log.Println(mname, "argument type not exported:", argType) continue } - if replyType.Elem().PkgPath() != "" && !isPublic(replyType.Elem().Name()) { - log.Stderr(mname, "reply type not public:", replyType) + if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { + log.Println(mname, "reply type not exported:", replyType) continue } if mtype.NumIn() == 4 { t := mtype.In(3) if t != reflect.Typeof((*ClientInfo)(nil)) { - log.Stderr(mname, "last argument not *ClientInfo") + log.Println(mname, "last argument not *ClientInfo") continue } } // Method needs one out: os.Error. if mtype.NumOut() != 1 { - log.Stderr("method", mname, "has wrong number of outs:", mtype.NumOut()) + log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) continue } if returnType := mtype.Out(0); returnType != typeOfOsError { - log.Stderr("method", mname, "returns", returnType.String(), "not os.Error") + log.Println("method", mname, "returns", returnType.String(), "not os.Error") continue } - s.method[mname] = &methodType{method: method, argType: argType, replyType: replyType} + s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} } if len(s.method) == 0 { - s := "rpc Register: type " + sname + " has no public methods of suitable type" - log.Stderr(s) + s := "rpc Register: type " + sname + " has no exported methods of suitable type" + log.Print(s) return os.ErrorString(s) } server.serviceMap[s.name] = s @@ -289,11 +321,18 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se sending.Lock() err := codec.WriteResponse(resp, reply) if err != nil { - log.Stderr("rpc: writing response: ", err) + log.Println("rpc: writing response:", err) } sending.Unlock() } +func (m *methodType) NumCalls() (n uint) { + m.Lock() + n = m.numCalls + m.Unlock() + return n +} + func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { mtype.Lock() mtype.numCalls++ @@ -335,7 +374,19 @@ func (c *gobServerCodec) Close() os.Error { return c.rwc.Close() } -func (server *serverType) input(codec ServerCodec) { + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) for { // Grab the request header. @@ -344,7 +395,7 @@ func (server *serverType) input(codec ServerCodec) { if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF { - log.Stderr("rpc: ", err) + log.Println("rpc:", err) } break } @@ -374,11 +425,11 @@ func (server *serverType) input(codec ServerCodec) { continue } // Decode the argument value. - argv := _new(mtype.argType) - replyv := _new(mtype.replyType) + argv := _new(mtype.ArgType) + replyv := _new(mtype.ReplyType) err = codec.ReadRequestBody(argv.Interface()) if err != nil { - log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err) + log.Println("rpc: tearing down", serviceMethod[0], "connection:", err) sendResponse(sending, req, replyv.Interface(), codec, err.String()) break } @@ -387,24 +438,27 @@ func (server *serverType) input(codec ServerCodec) { codec.Close() } -func (server *serverType) accept(lis net.Listener) { +// Accept accepts connections on the listener and serves requests +// for each incoming connection. Accept blocks; the caller typically +// invokes it in a go statement. +func (server *Server) Accept(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? } - go ServeConn(conn) + go server.ServeConn(conn) } } -// Register publishes in the server the set of methods of the -// receiver value that satisfy the following conditions: -// - public method -// - two arguments, both pointers to public structs -// - one return value of type os.Error -// It returns an error if the receiver is not public or has no -// suitable methods. -func Register(rcvr interface{}) os.Error { return server.register(rcvr) } +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) } + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func RegisterName(name string, rcvr interface{}) os.Error { + return DefaultServer.RegisterName(name, rcvr) +} // A ServerCodec implements reading of RPC requests and writing of // RPC responses for the server side of an RPC session. @@ -420,50 +474,57 @@ type ServerCodec interface { Close() os.Error } -// ServeConn runs the server on a single connection. +// ServeConn runs the DefaultServer on a single connection. // ServeConn blocks, serving the connection until the client hangs up. // The caller typically invokes ServeConn in a go statement. // ServeConn uses the gob wire format (see package gob) on the // connection. To use an alternate codec, use ServeCodec. func ServeConn(conn io.ReadWriteCloser) { - ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) + DefaultServer.ServeConn(conn) } // ServeCodec is like ServeConn but uses the specified codec to // decode requests and encode responses. func ServeCodec(codec ServerCodec) { - server.input(codec) + DefaultServer.ServeCodec(codec) } // Accept accepts connections on the listener and serves requests -// for each incoming connection. Accept blocks; the caller typically -// invokes it in a go statement. -func Accept(lis net.Listener) { server.accept(lis) } +// to DefaultServer for each incoming connection. +// Accept blocks; the caller typically invokes it in a go statement. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } // Can connect to RPC service using HTTP CONNECT to rpcPath. -var rpcPath string = "/_goRPC_" -var debugPath string = "/debug/rpc" var connected = "200 Connected to Go RPC" -func serveHTTP(c *http.Conn, req *http.Request) { +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "CONNECT" { - c.SetHeader("Content-Type", "text/plain; charset=utf-8") - c.WriteHeader(http.StatusMethodNotAllowed) - io.WriteString(c, "405 must CONNECT to "+rpcPath+"\n") + w.SetHeader("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") return } - conn, _, err := c.Hijack() + conn, _, err := w.Hijack() if err != nil { - log.Stderr("rpc hijacking ", c.RemoteAddr, ": ", err.String()) + log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String()) return } io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") - ServeConn(conn) + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP(rpcPath, debugPath string) { + http.Handle(rpcPath, server) + http.Handle(debugPath, debugHTTP{server}) } -// HandleHTTP registers an HTTP handler for RPC messages. +// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer +// on DefaultRPCPath and a debugging handler on DefaultDebugPath. // It is still necessary to invoke http.Serve(), typically in a go statement. func HandleHTTP() { - http.Handle(rpcPath, http.HandlerFunc(serveHTTP)) - http.Handle(debugPath, http.HandlerFunc(debugHTTP)) + DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) } diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index e502db4e3..355d51ce4 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -9,17 +9,23 @@ import ( "http" "log" "net" - "once" "os" "strings" + "sync" "testing" + "time" ) -var serverAddr string -var httpServerAddr string - -const second = 1e9 +var ( + serverAddr, newServerAddr string + httpServerAddr string + once, newOnce, httpOnce sync.Once +) +const ( + second = 1e9 + newHttpPath = "/foo" +) type Args struct { A, B int @@ -63,32 +69,56 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error { panic("ERROR") } -func startServer() { - Register(new(Arith)) - +func listenTCP() (net.Listener, string) { l, e := net.Listen("tcp", "127.0.0.1:0") // any available address if e != nil { log.Exitf("net.Listen tcp :0: %v", e) } - serverAddr = l.Addr().String() - log.Stderr("Test RPC server listening on ", serverAddr) + return l, l.Addr().String() +} + +func startServer() { + Register(new(Arith)) + + var l net.Listener + l, serverAddr = listenTCP() + log.Println("Test RPC server listening on", serverAddr) go Accept(l) HandleHTTP() - l, e = net.Listen("tcp", "127.0.0.1:0") // any available address - if e != nil { - log.Stderrf("net.Listen tcp :0: %v", e) - os.Exit(1) - } + httpOnce.Do(startHttpServer) +} + +func startNewServer() { + s := NewServer() + s.Register(new(Arith)) + + var l net.Listener + l, newServerAddr = listenTCP() + log.Println("NewServer test RPC server listening on", newServerAddr) + go Accept(l) + + s.HandleHTTP(newHttpPath, "/bar") + httpOnce.Do(startHttpServer) +} + +func startHttpServer() { + var l net.Listener + l, httpServerAddr = listenTCP() httpServerAddr = l.Addr().String() - log.Stderr("Test HTTP RPC server listening on ", httpServerAddr) + log.Println("Test HTTP RPC server listening on", httpServerAddr) go http.Serve(l, nil) } func TestRPC(t *testing.T) { once.Do(startServer) + testRPC(t, serverAddr) + newOnce.Do(startNewServer) + testRPC(t, newServerAddr) +} - client, err := Dial("tcp", serverAddr) +func testRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) if err != nil { t.Fatal("dialing", err) } @@ -174,8 +204,19 @@ func TestRPC(t *testing.T) { func TestHTTPRPC(t *testing.T) { once.Do(startServer) + testHTTPRPC(t, "") + newOnce.Do(startNewServer) + testHTTPRPC(t, newHttpPath) +} - client, err := DialHTTP("tcp", httpServerAddr) +func testHTTPRPC(t *testing.T, path string) { + var client *Client + var err os.Error + if path == "" { + client, err = DialHTTP("tcp", httpServerAddr) + } else { + client, err = DialHTTPPath("tcp", httpServerAddr, path) + } if err != nil { t.Fatal("dialing", err) } @@ -292,3 +333,52 @@ func TestRegistrationError(t *testing.T) { t.Errorf("expected error registering ReplyNotPublic") } } + +type WriteFailCodec int + +func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error { + // the panic caused by this error used to not unlock a lock. + return os.NewError("fail") +} + +func (WriteFailCodec) ReadResponseHeader(*Response) os.Error { + time.Sleep(60e9) + panic("unreachable") +} + +func (WriteFailCodec) ReadResponseBody(interface{}) os.Error { + time.Sleep(60e9) + panic("unreachable") +} + +func (WriteFailCodec) Close() os.Error { + return nil +} + +func TestSendDeadlock(t *testing.T) { + client := NewClientWithCodec(WriteFailCodec(0)) + + done := make(chan bool) + go func() { + testSendDeadlock(client) + testSendDeadlock(client) + done <- true + }() + for i := 0; i < 50; i++ { + time.Sleep(100 * 1e6) + _, ok := <-done + if ok { + return + } + } + t.Fatal("deadlock") +} + +func testSendDeadlock(client *Client) { + defer func() { + recover() + }() + args := &Args{7, 8} + reply := new(Reply) + client.Call("Arith.Add", args, reply) +} |