summaryrefslogtreecommitdiff
path: root/src/pkg/rpc
diff options
context:
space:
mode:
authorOndřej Surý <ondrej@sury.org>2011-09-13 13:13:40 +0200
committerOndřej Surý <ondrej@sury.org>2011-09-13 13:13:40 +0200
commit5ff4c17907d5b19510a62e08fd8d3b11e62b431d (patch)
treec0650497e988f47be9c6f2324fa692a52dea82e1 /src/pkg/rpc
parent80f18fc933cf3f3e829c5455a1023d69f7b86e52 (diff)
downloadgolang-upstream/60.tar.gz
Imported Upstream version 60upstream/60
Diffstat (limited to 'src/pkg/rpc')
-rw-r--r--src/pkg/rpc/Makefile13
-rw-r--r--src/pkg/rpc/client.go286
-rw-r--r--src/pkg/rpc/debug.go90
-rw-r--r--src/pkg/rpc/jsonrpc/Makefile12
-rw-r--r--src/pkg/rpc/jsonrpc/all_test.go156
-rw-r--r--src/pkg/rpc/jsonrpc/client.go124
-rw-r--r--src/pkg/rpc/jsonrpc/server.go136
-rw-r--r--src/pkg/rpc/server.go637
-rw-r--r--src/pkg/rpc/server_test.go501
9 files changed, 1955 insertions, 0 deletions
diff --git a/src/pkg/rpc/Makefile b/src/pkg/rpc/Makefile
new file mode 100644
index 000000000..191b10d05
--- /dev/null
+++ b/src/pkg/rpc/Makefile
@@ -0,0 +1,13 @@
+# Copyright 2009 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../Make.inc
+
+TARG=rpc
+GOFILES=\
+ client.go\
+ debug.go\
+ server.go\
+
+include ../../Make.pkg
diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go
new file mode 100644
index 000000000..4acfdf6d9
--- /dev/null
+++ b/src/pkg/rpc/client.go
@@ -0,0 +1,286 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "bufio"
+ "gob"
+ "http"
+ "io"
+ "log"
+ "net"
+ "os"
+ "sync"
+)
+
+// ServerError represents an error that has been returned from
+// the remote side of the RPC connection.
+type ServerError string
+
+func (e ServerError) String() string {
+ return string(e)
+}
+
+var ErrShutdown = os.NewError("connection is shut down")
+
+// Call represents an active RPC.
+type Call struct {
+ ServiceMethod string // The name of the service and method to call.
+ Args interface{} // The argument to the function (*struct).
+ Reply interface{} // The reply from the function (*struct).
+ Error os.Error // After completion, the error status.
+ Done chan *Call // Strobes when call is complete; value is the error status.
+ seq uint64
+}
+
+// Client represents an RPC Client.
+// There may be multiple outstanding Calls associated
+// with a single Client.
+type Client struct {
+ mutex sync.Mutex // protects pending, seq, request
+ sending sync.Mutex
+ request Request
+ seq uint64
+ codec ClientCodec
+ pending map[uint64]*Call
+ closing bool
+ shutdown bool
+}
+
+// A ClientCodec implements writing of RPC requests and
+// reading of RPC responses for the client side of an RPC session.
+// The client calls WriteRequest to write a request to the connection
+// and calls ReadResponseHeader and ReadResponseBody in pairs
+// to read responses. The client calls Close when finished with the
+// connection. ReadResponseBody may be called with a nil
+// argument to force the body of the response to be read and then
+// discarded.
+type ClientCodec interface {
+ WriteRequest(*Request, interface{}) os.Error
+ ReadResponseHeader(*Response) os.Error
+ ReadResponseBody(interface{}) os.Error
+
+ Close() os.Error
+}
+
+func (client *Client) send(c *Call) {
+ // Register this call.
+ client.mutex.Lock()
+ if client.shutdown {
+ c.Error = ErrShutdown
+ client.mutex.Unlock()
+ c.done()
+ return
+ }
+ c.seq = client.seq
+ client.seq++
+ client.pending[c.seq] = c
+ client.mutex.Unlock()
+
+ // Encode and send the request.
+ client.sending.Lock()
+ defer client.sending.Unlock()
+ client.request.Seq = c.seq
+ client.request.ServiceMethod = c.ServiceMethod
+ if err := client.codec.WriteRequest(&client.request, c.Args); err != nil {
+ panic("rpc: client encode error: " + err.String())
+ }
+}
+
+func (client *Client) input() {
+ var err os.Error
+ var response Response
+ for err == nil {
+ response = Response{}
+ err = client.codec.ReadResponseHeader(&response)
+ if err != nil {
+ if err == os.EOF && !client.closing {
+ err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+ seq := response.Seq
+ client.mutex.Lock()
+ c := client.pending[seq]
+ client.pending[seq] = c, false
+ client.mutex.Unlock()
+
+ if response.Error == "" {
+ err = client.codec.ReadResponseBody(c.Reply)
+ if err != nil {
+ c.Error = os.NewError("reading body " + err.String())
+ }
+ } else {
+ // We've got an error response. Give this to the request;
+ // any subsequent requests will get the ReadResponseBody
+ // error if there is one.
+ c.Error = ServerError(response.Error)
+ err = client.codec.ReadResponseBody(nil)
+ if err != nil {
+ err = os.NewError("reading error body: " + err.String())
+ }
+ }
+ c.done()
+ }
+ // Terminate pending calls.
+ client.mutex.Lock()
+ client.shutdown = true
+ for _, call := range client.pending {
+ call.Error = err
+ call.done()
+ }
+ client.mutex.Unlock()
+ if err != os.EOF || !client.closing {
+ log.Println("rpc: client protocol error:", err)
+ }
+}
+
+func (call *Call) done() {
+ select {
+ case call.Done <- call:
+ // ok
+ default:
+ // We don't want to block here. It is the caller's responsibility to make
+ // sure the channel has enough buffer space. See comment in Go().
+ }
+}
+
+// NewClient returns a new Client to handle requests to the
+// set of services at the other end of the connection.
+// It adds a buffer to the write side of the connection so
+// the header and payload are sent as a unit.
+func NewClient(conn io.ReadWriteCloser) *Client {
+ encBuf := bufio.NewWriter(conn)
+ client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
+ return NewClientWithCodec(client)
+}
+
+// NewClientWithCodec is like NewClient but uses the specified
+// codec to encode requests and decode responses.
+func NewClientWithCodec(codec ClientCodec) *Client {
+ client := &Client{
+ codec: codec,
+ pending: make(map[uint64]*Call),
+ }
+ go client.input()
+ return client
+}
+
+type gobClientCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+}
+
+func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err os.Error) {
+ if err = c.enc.Encode(r); err != nil {
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobClientCodec) Close() os.Error {
+ return c.rwc.Close()
+}
+
+// DialHTTP connects to an HTTP RPC server at the specified network address
+// listening on the default HTTP RPC path.
+func DialHTTP(network, address string) (*Client, os.Error) {
+ return DialHTTPPath(network, address, DefaultRPCPath)
+}
+
+// DialHTTPPath connects to an HTTP RPC server
+// at the specified network address and path.
+func DialHTTPPath(network, address, path string) (*Client, os.Error) {
+ var err os.Error
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
+
+ // Require successful HTTP response
+ // before switching to RPC protocol.
+ resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
+ if err == nil && resp.Status == connected {
+ return NewClient(conn), nil
+ }
+ if err == nil {
+ err = os.NewError("unexpected HTTP response: " + resp.Status)
+ }
+ conn.Close()
+ return nil, &net.OpError{"dial-http", network + " " + address, nil, err}
+}
+
+// Dial connects to an RPC server at the specified network address.
+func Dial(network, address string) (*Client, os.Error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), nil
+}
+
+func (client *Client) Close() os.Error {
+ client.mutex.Lock()
+ if client.shutdown || client.closing {
+ client.mutex.Unlock()
+ return ErrShutdown
+ }
+ client.closing = true
+ client.mutex.Unlock()
+ return client.codec.Close()
+}
+
+// Go invokes the function asynchronously. It returns the Call structure representing
+// the invocation. The done channel will signal when the call is complete by returning
+// the same Call object. If done is nil, Go will allocate a new channel.
+// If non-nil, done must be buffered or Go will deliberately crash.
+func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+ c := new(Call)
+ c.ServiceMethod = serviceMethod
+ c.Args = args
+ c.Reply = reply
+ if done == nil {
+ done = make(chan *Call, 10) // buffered.
+ } else {
+ // If caller passes done != nil, it must arrange that
+ // done has enough buffer for the number of simultaneous
+ // RPCs that will be using that channel. If the channel
+ // is totally unbuffered, it's best not to run at all.
+ if cap(done) == 0 {
+ log.Panic("rpc: done channel is unbuffered")
+ }
+ }
+ c.Done = done
+ if client.shutdown {
+ c.Error = ErrShutdown
+ c.done()
+ return c
+ }
+ client.send(c)
+ return c
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
+ if client.shutdown {
+ return ErrShutdown
+ }
+ call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
+ return call.Error
+}
diff --git a/src/pkg/rpc/debug.go b/src/pkg/rpc/debug.go
new file mode 100644
index 000000000..7e3e6f6e5
--- /dev/null
+++ b/src/pkg/rpc/debug.go
@@ -0,0 +1,90 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+/*
+ Some HTML presented at http://machine:port/debug/rpc
+ Lists services, their methods, and some statistics, still rudimentary.
+*/
+
+import (
+ "fmt"
+ "http"
+ "sort"
+ "template"
+)
+
+const debugText = `<html>
+ <body>
+ <title>Services</title>
+ {{range .}}
+ <hr>
+ Service {{.Name}}
+ <hr>
+ <table>
+ <th align=center>Method</th><th align=center>Calls</th>
+ {{range .Method}}
+ <tr>
+ <td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) os.Error</td>
+ <td align=center>{{.Type.NumCalls}}</td>
+ </tr>
+ {{end}}
+ </table>
+ {{end}}
+ </body>
+ </html>`
+
+var debug = template.Must(template.New("RPC debug").Parse(debugText))
+
+type debugMethod struct {
+ Type *methodType
+ Name string
+}
+
+type methodArray []debugMethod
+
+type debugService struct {
+ Service *service
+ Name string
+ Method methodArray
+}
+
+type serviceArray []debugService
+
+func (s serviceArray) Len() int { return len(s) }
+func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
+func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func (m methodArray) Len() int { return len(m) }
+func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
+func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+type debugHTTP struct {
+ *Server
+}
+
+// Runs at /debug/rpc
+func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Build a sorted version of the data.
+ var services = make(serviceArray, len(server.serviceMap))
+ i := 0
+ server.mu.Lock()
+ for sname, service := range server.serviceMap {
+ services[i] = debugService{service, sname, make(methodArray, len(service.method))}
+ j := 0
+ for mname, method := range service.method {
+ services[i].Method[j] = debugMethod{method, mname}
+ j++
+ }
+ sort.Sort(services[i].Method)
+ i++
+ }
+ server.mu.Unlock()
+ sort.Sort(services)
+ err := debug.Execute(w, services)
+ if err != nil {
+ fmt.Fprintln(w, "rpc: error executing template:", err.String())
+ }
+}
diff --git a/src/pkg/rpc/jsonrpc/Makefile b/src/pkg/rpc/jsonrpc/Makefile
new file mode 100644
index 000000000..b9a1ac2f7
--- /dev/null
+++ b/src/pkg/rpc/jsonrpc/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2010 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=rpc/jsonrpc
+GOFILES=\
+ client.go\
+ server.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/rpc/jsonrpc/all_test.go b/src/pkg/rpc/jsonrpc/all_test.go
new file mode 100644
index 000000000..c1a9e8ecb
--- /dev/null
+++ b/src/pkg/rpc/jsonrpc/all_test.go
@@ -0,0 +1,156 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "fmt"
+ "json"
+ "net"
+ "os"
+ "rpc"
+ "testing"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+func (t *Arith) Add(args *Args, reply *Reply) os.Error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) os.Error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args, reply *Reply) os.Error {
+ if args.B == 0 {
+ return os.NewError("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) os.Error {
+ panic("ERROR")
+}
+
+func init() {
+ rpc.Register(new(Arith))
+}
+
+func TestServer(t *testing.T) {
+ type addResp struct {
+ Id interface{} `json:"id"`
+ Result Reply `json:"result"`
+ Error interface{} `json:"error"`
+ }
+
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ // Send hand-coded requests to server, parse responses.
+ for i := 0; i < 10; i++ {
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
+ var resp addResp
+ err := dec.Decode(&resp)
+ if err != nil {
+ t.Fatalf("Decode: %s", err)
+ }
+ if resp.Error != nil {
+ t.Fatalf("resp.Error: %s", resp.Error)
+ }
+ if resp.Id.(string) != string(i) {
+ t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i))
+ }
+ if resp.Result.C != 2*i+1 {
+ t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
+ }
+ }
+
+ fmt.Fprintf(cli, "{}\n")
+ var resp addResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestClient(t *testing.T) {
+ // Assume server is okay (TestServer is above).
+ // Test client against server.
+ cli, srv := net.Pipe()
+ go ServeConn(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.String())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.String() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+}
diff --git a/src/pkg/rpc/jsonrpc/client.go b/src/pkg/rpc/jsonrpc/client.go
new file mode 100644
index 000000000..577d0ce42
--- /dev/null
+++ b/src/pkg/rpc/jsonrpc/client.go
@@ -0,0 +1,124 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec
+// for the rpc package.
+package jsonrpc
+
+import (
+ "fmt"
+ "io"
+ "json"
+ "net"
+ "os"
+ "rpc"
+ "sync"
+)
+
+type clientCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req clientRequest
+ resp clientResponse
+
+ // JSON-RPC responses include the request id but not the request method.
+ // Package rpc expects both.
+ // We save the request method in pending when sending a request
+ // and then look it up by request ID when filling out the rpc Response.
+ mutex sync.Mutex // protects pending
+ pending map[uint64]string // map request id to method name
+}
+
+// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
+func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
+ return &clientCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]string),
+ }
+}
+
+type clientRequest struct {
+ Method string `json:"method"`
+ Params [1]interface{} `json:"params"`
+ Id uint64 `json:"id"`
+}
+
+func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) os.Error {
+ c.mutex.Lock()
+ c.pending[r.Seq] = r.ServiceMethod
+ c.mutex.Unlock()
+ c.req.Method = r.ServiceMethod
+ c.req.Params[0] = param
+ c.req.Id = r.Seq
+ return c.enc.Encode(&c.req)
+}
+
+type clientResponse struct {
+ Id uint64 `json:"id"`
+ Result *json.RawMessage `json:"result"`
+ Error interface{} `json:"error"`
+}
+
+func (r *clientResponse) reset() {
+ r.Id = 0
+ r.Result = nil
+ r.Error = nil
+}
+
+func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error {
+ c.resp.reset()
+ if err := c.dec.Decode(&c.resp); err != nil {
+ return err
+ }
+
+ c.mutex.Lock()
+ r.ServiceMethod = c.pending[c.resp.Id]
+ c.pending[c.resp.Id] = "", false
+ c.mutex.Unlock()
+
+ r.Error = ""
+ r.Seq = c.resp.Id
+ if c.resp.Error != nil {
+ x, ok := c.resp.Error.(string)
+ if !ok {
+ return fmt.Errorf("invalid error %v", c.resp.Error)
+ }
+ if x == "" {
+ x = "unspecified error"
+ }
+ r.Error = x
+ }
+ return nil
+}
+
+func (c *clientCodec) ReadResponseBody(x interface{}) os.Error {
+ if x == nil {
+ return nil
+ }
+ return json.Unmarshal(*c.resp.Result, x)
+}
+
+func (c *clientCodec) Close() os.Error {
+ return c.c.Close()
+}
+
+// NewClient returns a new rpc.Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *rpc.Client {
+ return rpc.NewClientWithCodec(NewClientCodec(conn))
+}
+
+// Dial connects to a JSON-RPC server at the specified network address.
+func Dial(network, address string) (*rpc.Client, os.Error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), err
+}
diff --git a/src/pkg/rpc/jsonrpc/server.go b/src/pkg/rpc/jsonrpc/server.go
new file mode 100644
index 000000000..9801fdf22
--- /dev/null
+++ b/src/pkg/rpc/jsonrpc/server.go
@@ -0,0 +1,136 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "io"
+ "json"
+ "os"
+ "rpc"
+ "sync"
+)
+
+type serverCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req serverRequest
+ resp serverResponse
+
+ // JSON-RPC clients can use arbitrary json values as request IDs.
+ // Package rpc expects uint64 request IDs.
+ // We assign uint64 sequence numbers to incoming requests
+ // but save the original request ID in the pending map.
+ // When rpc responds, we use the sequence number in
+ // the response to find the original request ID.
+ mutex sync.Mutex // protects seq, pending
+ seq uint64
+ pending map[uint64]*json.RawMessage
+}
+
+// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
+func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
+ return &serverCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]*json.RawMessage),
+ }
+}
+
+type serverRequest struct {
+ Method string `json:"method"`
+ Params *json.RawMessage `json:"params"`
+ Id *json.RawMessage `json:"id"`
+}
+
+func (r *serverRequest) reset() {
+ r.Method = ""
+ if r.Params != nil {
+ *r.Params = (*r.Params)[0:0]
+ }
+ if r.Id != nil {
+ *r.Id = (*r.Id)[0:0]
+ }
+}
+
+type serverResponse struct {
+ Id *json.RawMessage `json:"id"`
+ Result interface{} `json:"result"`
+ Error interface{} `json:"error"`
+}
+
+func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error {
+ c.req.reset()
+ if err := c.dec.Decode(&c.req); err != nil {
+ return err
+ }
+ r.ServiceMethod = c.req.Method
+
+ // JSON request id can be any JSON value;
+ // RPC package expects uint64. Translate to
+ // internal uint64 and save JSON on the side.
+ c.mutex.Lock()
+ c.seq++
+ c.pending[c.seq] = c.req.Id
+ c.req.Id = nil
+ r.Seq = c.seq
+ c.mutex.Unlock()
+
+ return nil
+}
+
+func (c *serverCodec) ReadRequestBody(x interface{}) os.Error {
+ if x == nil {
+ return nil
+ }
+ // JSON params is array value.
+ // RPC params is struct.
+ // Unmarshal into array containing struct for now.
+ // Should think about making RPC more general.
+ var params [1]interface{}
+ params[0] = x
+ return json.Unmarshal(*c.req.Params, &params)
+}
+
+var null = json.RawMessage([]byte("null"))
+
+func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error {
+ var resp serverResponse
+ c.mutex.Lock()
+ b, ok := c.pending[r.Seq]
+ if !ok {
+ c.mutex.Unlock()
+ return os.NewError("invalid sequence number in response")
+ }
+ c.pending[r.Seq] = nil, false
+ c.mutex.Unlock()
+
+ if b == nil {
+ // Invalid request so no id. Use JSON null.
+ b = &null
+ }
+ resp.Id = b
+ resp.Result = x
+ if r.Error == "" {
+ resp.Error = nil
+ } else {
+ resp.Error = r.Error
+ }
+ return c.enc.Encode(resp)
+}
+
+func (c *serverCodec) Close() os.Error {
+ return c.c.Close()
+}
+
+// ServeConn runs the JSON-RPC server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+func ServeConn(conn io.ReadWriteCloser) {
+ rpc.ServeCodec(NewServerCodec(conn))
+}
diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go
new file mode 100644
index 000000000..745074428
--- /dev/null
+++ b/src/pkg/rpc/server.go
@@ -0,0 +1,637 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Package rpc provides access to the exported methods of an object across a
+ network or other I/O connection. A server registers an object, making it visible
+ as a service with the name of the type of the object. After registration, exported
+ methods of the object will be accessible remotely. A server may register multiple
+ objects (services) of different types but it is an error to register multiple
+ objects of the same type.
+
+ Only methods that satisfy these criteria will be made available for remote access;
+ other methods will be ignored:
+
+ - the method name is exported, that is, begins with an upper case letter.
+ - the method receiver is exported or local (defined in the package
+ registering the service).
+ - the method has two arguments, both exported or local types.
+ - the method's second argument is a pointer.
+ - the method has return type os.Error.
+
+ The method's first argument represents the arguments provided by the caller; the
+ second argument represents the result parameters to be returned to the caller.
+ The method's return value, if non-nil, is passed back as a string that the client
+ sees as an os.ErrorString.
+
+ The server may handle requests on a single connection by calling ServeConn. More
+ typically it will create a network listener and call Accept or, for an HTTP
+ listener, HandleHTTP and http.Serve.
+
+ A client wishing to use the service establishes a connection and then invokes
+ NewClient on the connection. The convenience function Dial (DialHTTP) performs
+ both steps for a raw network connection (an HTTP connection). The resulting
+ Client object has two methods, Call and Go, that specify the service and method to
+ call, a pointer containing the arguments, and a pointer to receive the result
+ parameters.
+
+ Call waits for the remote call to complete; Go launches the call asynchronously
+ and returns a channel that will signal completion.
+
+ Package "gob" is used to transport the data.
+
+ Here is a simple example. A server wishes to export an object of type Arith:
+
+ package server
+
+ type Args struct {
+ A, B int
+ }
+
+ type Quotient struct {
+ Quo, Rem int
+ }
+
+ type Arith int
+
+ func (t *Arith) Multiply(args *Args, reply *int) os.Error {
+ *reply = args.A * args.B
+ return nil
+ }
+
+ func (t *Arith) Divide(args *Args, quo *Quotient) os.Error {
+ if args.B == 0 {
+ return os.ErrorString("divide by zero")
+ }
+ quo.Quo = args.A / args.B
+ quo.Rem = args.A % args.B
+ return nil
+ }
+
+ The server calls (for HTTP service):
+
+ arith := new(Arith)
+ rpc.Register(arith)
+ rpc.HandleHTTP()
+ l, e := net.Listen("tcp", ":1234")
+ if e != nil {
+ log.Fatal("listen error:", e)
+ }
+ go http.Serve(l, nil)
+
+ At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
+ "Arith.Divide". To invoke one, a client first dials the server:
+
+ client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
+ if err != nil {
+ log.Fatal("dialing:", err)
+ }
+
+ Then it can make a remote call:
+
+ // Synchronous call
+ args := &server.Args{7,8}
+ var reply int
+ err = client.Call("Arith.Multiply", args, &reply)
+ if err != nil {
+ log.Fatal("arith error:", err)
+ }
+ fmt.Printf("Arith: %d*%d=%d", args.A, args.B, *reply)
+
+ or
+
+ // Asynchronous call
+ quotient := new(Quotient)
+ divCall := client.Go("Arith.Divide", args, &quotient, nil)
+ replyCall := <-divCall.Done // will be equal to divCall
+ // check errors, print, etc.
+
+ A server implementation will often provide a simple, type-safe wrapper for the
+ client.
+*/
+package rpc
+
+import (
+ "bufio"
+ "gob"
+ "http"
+ "log"
+ "io"
+ "net"
+ "os"
+ "reflect"
+ "strings"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+const (
+ // Defaults used by HandleHTTP
+ DefaultRPCPath = "/_goRPC_"
+ DefaultDebugPath = "/debug/rpc"
+)
+
+// Precompute the reflect type for os.Error. Can't use os.Error directly
+// because Typeof takes an empty interface value. This is annoying.
+var unusedError *os.Error
+var typeOfOsError = reflect.TypeOf(unusedError).Elem()
+
+type methodType struct {
+ sync.Mutex // protects counters
+ method reflect.Method
+ ArgType reflect.Type
+ ReplyType reflect.Type
+ numCalls uint
+}
+
+type service struct {
+ name string // name of service
+ rcvr reflect.Value // receiver of methods for the service
+ typ reflect.Type // type of the receiver
+ method map[string]*methodType // registered methods
+}
+
+// Request is a header written before every RPC call. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Request struct {
+ ServiceMethod string // format: "Service.Method"
+ Seq uint64 // sequence number chosen by client
+ next *Request // for free list in Server
+}
+
+// Response is a header written before every RPC return. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Response struct {
+ ServiceMethod string // echoes that of the Request
+ Seq uint64 // echoes that of the request
+ Error string // error, if any.
+ next *Response // for free list in Server
+}
+
+// Server represents an RPC Server.
+type Server struct {
+ mu sync.Mutex // protects the serviceMap
+ serviceMap map[string]*service
+ reqLock sync.Mutex // protects freeReq
+ freeReq *Request
+ respLock sync.Mutex // protects freeResp
+ freeResp *Response
+}
+
+// NewServer returns a new Server.
+func NewServer() *Server {
+ return &Server{serviceMap: make(map[string]*service)}
+}
+
+// DefaultServer is the default instance of *Server.
+var DefaultServer = NewServer()
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// Is this type exported or a builtin?
+func isExportedOrBuiltinType(t reflect.Type) bool {
+ for t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ // PkgPath will be non-empty even for an exported type,
+ // so we need to check the type name as well.
+ return isExported(t.Name()) || t.PkgPath() == ""
+}
+
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method
+// - two arguments, both pointers to exported structs
+// - one return value, of type os.Error
+// It returns an error if the receiver is not an exported type or has no
+// suitable methods.
+// The client accesses each method using a string of the form "Type.Method",
+// where Type is the receiver's concrete type.
+func (server *Server) Register(rcvr interface{}) os.Error {
+ return server.register(rcvr, "", false)
+}
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func (server *Server) RegisterName(name string, rcvr interface{}) os.Error {
+ return server.register(rcvr, name, true)
+}
+
+func (server *Server) register(rcvr interface{}, name string, useName bool) os.Error {
+ server.mu.Lock()
+ defer server.mu.Unlock()
+ if server.serviceMap == nil {
+ server.serviceMap = make(map[string]*service)
+ }
+ s := new(service)
+ s.typ = reflect.TypeOf(rcvr)
+ s.rcvr = reflect.ValueOf(rcvr)
+ sname := reflect.Indirect(s.rcvr).Type().Name()
+ if useName {
+ sname = name
+ }
+ if sname == "" {
+ log.Fatal("rpc: no service name for type", s.typ.String())
+ }
+ if !isExported(sname) && !useName {
+ s := "rpc Register: type " + sname + " is not exported"
+ log.Print(s)
+ return os.NewError(s)
+ }
+ if _, present := server.serviceMap[sname]; present {
+ return os.NewError("rpc: service already defined: " + sname)
+ }
+ s.name = sname
+ s.method = make(map[string]*methodType)
+
+ // Install the methods
+ for m := 0; m < s.typ.NumMethod(); m++ {
+ method := s.typ.Method(m)
+ mtype := method.Type
+ mname := method.Name
+ if method.PkgPath != "" {
+ continue
+ }
+ // Method needs three ins: receiver, *args, *reply.
+ if mtype.NumIn() != 3 {
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ continue
+ }
+ // First arg need not be a pointer.
+ argType := mtype.In(1)
+ if !isExportedOrBuiltinType(argType) {
+ log.Println(mname, "argument type not exported or local:", argType)
+ continue
+ }
+ // Second arg must be a pointer.
+ replyType := mtype.In(2)
+ if replyType.Kind() != reflect.Ptr {
+ log.Println("method", mname, "reply type not a pointer:", replyType)
+ continue
+ }
+ if !isExportedOrBuiltinType(replyType) {
+ log.Println("method", mname, "reply type not exported or local:", replyType)
+ continue
+ }
+ // Method needs one out: os.Error.
+ if mtype.NumOut() != 1 {
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ continue
+ }
+ if returnType := mtype.Out(0); returnType != typeOfOsError {
+ log.Println("method", mname, "returns", returnType.String(), "not os.Error")
+ continue
+ }
+ s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ }
+
+ if len(s.method) == 0 {
+ s := "rpc Register: type " + sname + " has no exported methods of suitable type"
+ log.Print(s)
+ return os.NewError(s)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+// A value sent as a placeholder for the response when the server receives an invalid request.
+type InvalidRequest struct{}
+
+var invalidRequest = InvalidRequest{}
+
+func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
+ resp := server.getResponse()
+ // Encode the response header
+ resp.ServiceMethod = req.ServiceMethod
+ if errmsg != "" {
+ resp.Error = errmsg
+ reply = invalidRequest
+ }
+ resp.Seq = req.Seq
+ sending.Lock()
+ err := codec.WriteResponse(resp, reply)
+ if err != nil {
+ log.Println("rpc: writing response:", err)
+ }
+ sending.Unlock()
+ server.freeResponse(resp)
+}
+
+func (m *methodType) NumCalls() (n uint) {
+ m.Lock()
+ n = m.numCalls
+ m.Unlock()
+ return n
+}
+
+func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+ mtype.Lock()
+ mtype.numCalls++
+ mtype.Unlock()
+ function := mtype.method.Func
+ // Invoke the method, providing a new value for the reply.
+ returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
+ // The return value for the method is an os.Error.
+ errInter := returnValues[0].Interface()
+ errmsg := ""
+ if errInter != nil {
+ errmsg = errInter.(os.Error).String()
+ }
+ server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+ server.freeRequest(req)
+}
+
+type gobServerCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+}
+
+func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err os.Error) {
+ if err = c.enc.Encode(r); err != nil {
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobServerCodec) Close() os.Error {
+ return c.rwc.Close()
+}
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
+ buf := bufio.NewWriter(conn)
+ srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
+ server.ServeCodec(srv)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func (server *Server) ServeCodec(codec ServerCodec) {
+ sending := new(sync.Mutex)
+ for {
+ service, mtype, req, argv, replyv, err := server.readRequest(codec)
+ if err != nil {
+ if err != os.EOF {
+ log.Println("rpc:", err)
+ }
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ break
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.String())
+ server.freeRequest(req)
+ }
+ continue
+ }
+ go service.call(server, sending, mtype, req, argv, replyv, codec)
+ }
+ codec.Close()
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func (server *Server) ServeRequest(codec ServerCodec) os.Error {
+ sending := new(sync.Mutex)
+ service, mtype, req, argv, replyv, err := server.readRequest(codec)
+ if err != nil {
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ return err
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.String())
+ server.freeRequest(req)
+ }
+ return err
+ }
+ service.call(server, sending, mtype, req, argv, replyv, codec)
+ return nil
+}
+
+func (server *Server) getRequest() *Request {
+ server.reqLock.Lock()
+ req := server.freeReq
+ if req == nil {
+ req = new(Request)
+ } else {
+ server.freeReq = req.next
+ *req = Request{}
+ }
+ server.reqLock.Unlock()
+ return req
+}
+
+func (server *Server) freeRequest(req *Request) {
+ server.reqLock.Lock()
+ req.next = server.freeReq
+ server.freeReq = req
+ server.reqLock.Unlock()
+}
+
+func (server *Server) getResponse() *Response {
+ server.respLock.Lock()
+ resp := server.freeResp
+ if resp == nil {
+ resp = new(Response)
+ } else {
+ server.freeResp = resp.next
+ *resp = Response{}
+ }
+ server.respLock.Unlock()
+ return resp
+}
+
+func (server *Server) freeResponse(resp *Response) {
+ server.respLock.Lock()
+ resp.next = server.freeResp
+ server.freeResp = resp
+ server.respLock.Unlock()
+}
+
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) {
+ service, mtype, req, err = server.readRequestHeader(codec)
+ if err != nil {
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ // discard body
+ codec.ReadRequestBody(nil)
+ return
+ }
+
+ // Decode the argument value.
+ argIsValue := false // if true, need to indirect before calling.
+ if mtype.ArgType.Kind() == reflect.Ptr {
+ argv = reflect.New(mtype.ArgType.Elem())
+ } else {
+ argv = reflect.New(mtype.ArgType)
+ argIsValue = true
+ }
+ // argv guaranteed to be a pointer now.
+ if err = codec.ReadRequestBody(argv.Interface()); err != nil {
+ return
+ }
+ if argIsValue {
+ argv = argv.Elem()
+ }
+
+ replyv = reflect.New(mtype.ReplyType.Elem())
+ return
+}
+
+func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) {
+ // Grab the request header.
+ req = server.getRequest()
+ err = codec.ReadRequestHeader(req)
+ if err != nil {
+ req = nil
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ err = os.NewError("rpc: server cannot decode request: " + err.String())
+ return
+ }
+
+ serviceMethod := strings.Split(req.ServiceMethod, ".")
+ if len(serviceMethod) != 2 {
+ err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod)
+ return
+ }
+ // Look up the request.
+ server.mu.Lock()
+ service = server.serviceMap[serviceMethod[0]]
+ server.mu.Unlock()
+ if service == nil {
+ err = os.NewError("rpc: can't find service " + req.ServiceMethod)
+ return
+ }
+ mtype = service.method[serviceMethod[1]]
+ if mtype == nil {
+ err = os.NewError("rpc: can't find method " + req.ServiceMethod)
+ }
+ return
+}
+
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection. Accept blocks; the caller typically
+// invokes it in a go statement.
+func (server *Server) Accept(lis net.Listener) {
+ for {
+ conn, err := lis.Accept()
+ if err != nil {
+ log.Fatal("rpc.Serve: accept:", err.String()) // TODO(r): exit?
+ }
+ go server.ServeConn(conn)
+ }
+}
+
+// Register publishes the receiver's methods in the DefaultServer.
+func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) }
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func RegisterName(name string, rcvr interface{}) os.Error {
+ return DefaultServer.RegisterName(name, rcvr)
+}
+
+// A ServerCodec implements reading of RPC requests and writing of
+// RPC responses for the server side of an RPC session.
+// The server calls ReadRequestHeader and ReadRequestBody in pairs
+// to read requests from the connection, and it calls WriteResponse to
+// write a response back. The server calls Close when finished with the
+// connection. ReadRequestBody may be called with a nil
+// argument to force the body of the request to be read and discarded.
+type ServerCodec interface {
+ ReadRequestHeader(*Request) os.Error
+ ReadRequestBody(interface{}) os.Error
+ WriteResponse(*Response, interface{}) os.Error
+
+ Close() os.Error
+}
+
+// ServeConn runs the DefaultServer on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func ServeConn(conn io.ReadWriteCloser) {
+ DefaultServer.ServeConn(conn)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func ServeCodec(codec ServerCodec) {
+ DefaultServer.ServeCodec(codec)
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func ServeRequest(codec ServerCodec) os.Error {
+ return DefaultServer.ServeRequest(codec)
+}
+
+// Accept accepts connections on the listener and serves requests
+// to DefaultServer for each incoming connection.
+// Accept blocks; the caller typically invokes it in a go statement.
+func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
+
+// Can connect to RPC service using HTTP CONNECT to rpcPath.
+var connected = "200 Connected to Go RPC"
+
+// ServeHTTP implements an http.Handler that answers RPC requests.
+func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if req.Method != "CONNECT" {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ io.WriteString(w, "405 must CONNECT\n")
+ return
+ }
+ conn, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.String())
+ return
+ }
+ io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
+ server.ServeConn(conn)
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
+// and a debugging handler on debugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func (server *Server) HandleHTTP(rpcPath, debugPath string) {
+ http.Handle(rpcPath, server)
+ http.Handle(debugPath, debugHTTP{server})
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
+// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func HandleHTTP() {
+ DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
+}
diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go
new file mode 100644
index 000000000..e7bbfbe97
--- /dev/null
+++ b/src/pkg/rpc/server_test.go
@@ -0,0 +1,501 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "fmt"
+ "http/httptest"
+ "io"
+ "log"
+ "net"
+ "os"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var (
+ newServer *Server
+ serverAddr, newServerAddr string
+ httpServerAddr string
+ once, newOnce, httpOnce sync.Once
+)
+
+const (
+ second = 1e9
+ newHttpPath = "/foo"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+// Some of Arith's methods have value args, some have pointer args. That's deliberate.
+
+func (t *Arith) Add(args Args, reply *Reply) os.Error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) os.Error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args Args, reply *Reply) os.Error {
+ if args.B == 0 {
+ return os.NewError("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) String(args *Args, reply *string) os.Error {
+ *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ return nil
+}
+
+func (t *Arith) Scan(args string, reply *Reply) (err os.Error) {
+ _, err = fmt.Sscan(args, &reply.C)
+ return
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) os.Error {
+ panic("ERROR")
+}
+
+func listenTCP() (net.Listener, string) {
+ l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
+ if e != nil {
+ log.Fatalf("net.Listen tcp :0: %v", e)
+ }
+ return l, l.Addr().String()
+}
+
+func startServer() {
+ Register(new(Arith))
+
+ var l net.Listener
+ l, serverAddr = listenTCP()
+ log.Println("Test RPC server listening on", serverAddr)
+ go Accept(l)
+
+ HandleHTTP()
+ httpOnce.Do(startHttpServer)
+}
+
+func startNewServer() {
+ newServer = NewServer()
+ newServer.Register(new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ log.Println("NewServer test RPC server listening on", newServerAddr)
+ go Accept(l)
+
+ newServer.HandleHTTP(newHttpPath, "/bar")
+ httpOnce.Do(startHttpServer)
+}
+
+func startHttpServer() {
+ server := httptest.NewServer(nil)
+ httpServerAddr = server.Listener.Addr().String()
+ log.Println("Test HTTP RPC server listening on", httpServerAddr)
+}
+
+func TestRPC(t *testing.T) {
+ once.Do(startServer)
+ testRPC(t, serverAddr)
+ newOnce.Do(startNewServer)
+ testRPC(t, newServerAddr)
+}
+
+func testRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ // Nonexistent method
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.BadOperation", args, reply)
+ // expect an error
+ if err == nil {
+ t.Error("BadOperation: expected error")
+ } else if !strings.HasPrefix(err.String(), "rpc: can't find method ") {
+ t.Errorf("BadOperation: expected can't find method error; got %q", err)
+ }
+
+ // Unknown service
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Unknown", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if strings.Index(err.String(), "method") < 0 {
+ t.Error("expected error about method; got", err)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.String())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.String() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+
+ // Bad type.
+ reply = new(Reply)
+ err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+ if err == nil {
+ t.Error("expected error calling Arith.Add with wrong arg type")
+ } else if strings.Index(err.String(), "type") < 0 {
+ t.Error("expected error about type; got", err)
+ }
+
+ // Non-struct argument
+ const Val = 12345
+ str := fmt.Sprint(Val)
+ reply = new(Reply)
+ err = client.Call("Arith.Scan", &str, reply)
+ if err != nil {
+ t.Errorf("Scan: expected no error but got string %q", err.String())
+ } else if reply.C != Val {
+ t.Errorf("Scan: expected %d got %d", Val, reply.C)
+ }
+
+ // Non-struct reply
+ args = &Args{27, 35}
+ str = ""
+ err = client.Call("Arith.String", args, &str)
+ if err != nil {
+ t.Errorf("String: expected no error but got string %q", err.String())
+ }
+ expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ if str != expect {
+ t.Errorf("String: expected %s got %s", expect, str)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+}
+
+func TestHTTP(t *testing.T) {
+ once.Do(startServer)
+ testHTTPRPC(t, "")
+ newOnce.Do(startNewServer)
+ testHTTPRPC(t, newHttpPath)
+}
+
+func testHTTPRPC(t *testing.T, path string) {
+ var client *Client
+ var err os.Error
+ if path == "" {
+ client, err = DialHTTP("tcp", httpServerAddr)
+ } else {
+ client, err = DialHTTPPath("tcp", httpServerAddr, path)
+ }
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+}
+
+// CodecEmulator provides a client-like api and a ServerCodec interface.
+// Can be used to test ServeRequest.
+type CodecEmulator struct {
+ server *Server
+ serviceMethod string
+ args *Args
+ reply *Reply
+ err os.Error
+}
+
+func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) os.Error {
+ codec.serviceMethod = serviceMethod
+ codec.args = args
+ codec.reply = reply
+ codec.err = nil
+ var serverError os.Error
+ if codec.server == nil {
+ serverError = ServeRequest(codec)
+ } else {
+ serverError = codec.server.ServeRequest(codec)
+ }
+ if codec.err == nil && serverError != nil {
+ codec.err = serverError
+ }
+ return codec.err
+}
+
+func (codec *CodecEmulator) ReadRequestHeader(req *Request) os.Error {
+ req.ServiceMethod = codec.serviceMethod
+ req.Seq = 0
+ return nil
+}
+
+func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error {
+ if codec.args == nil {
+ return io.ErrUnexpectedEOF
+ }
+ *(argv.(*Args)) = *codec.args
+ return nil
+}
+
+func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
+ if resp.Error != "" {
+ codec.err = os.NewError(resp.Error)
+ }
+ *codec.reply = *(reply.(*Reply))
+ return nil
+}
+
+func (codec *CodecEmulator) Close() os.Error {
+ return nil
+}
+
+func TestServeRequest(t *testing.T) {
+ once.Do(startServer)
+ testServeRequest(t, nil)
+ newOnce.Do(startNewServer)
+ testServeRequest(t, newServer)
+}
+
+func testServeRequest(t *testing.T, server *Server) {
+ client := CodecEmulator{server: server}
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ err = client.Call("Arith.Add", nil, reply)
+ if err == nil {
+ t.Errorf("expected error calling Arith.Add with nil arg")
+ }
+}
+
+type ReplyNotPointer int
+type ArgNotPublic int
+type ReplyNotPublic int
+type local struct{}
+
+func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error {
+ return nil
+}
+
+func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) os.Error {
+ return nil
+}
+
+func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error {
+ return nil
+}
+
+// Check that registration handles lots of bad methods and a type with no suitable methods.
+func TestRegistrationError(t *testing.T) {
+ err := Register(new(ReplyNotPointer))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPointer")
+ }
+ err = Register(new(ArgNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ArgNotPublic")
+ }
+ err = Register(new(ReplyNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPublic")
+ }
+}
+
+type WriteFailCodec int
+
+func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error {
+ // the panic caused by this error used to not unlock a lock.
+ return os.NewError("fail")
+}
+
+func (WriteFailCodec) ReadResponseHeader(*Response) os.Error {
+ time.Sleep(120e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) ReadResponseBody(interface{}) os.Error {
+ time.Sleep(120e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) Close() os.Error {
+ return nil
+}
+
+func TestSendDeadlock(t *testing.T) {
+ client := NewClientWithCodec(WriteFailCodec(0))
+
+ done := make(chan bool)
+ go func() {
+ testSendDeadlock(client)
+ testSendDeadlock(client)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-time.After(5e9):
+ t.Fatal("deadlock")
+ }
+}
+
+func testSendDeadlock(client *Client) {
+ defer func() {
+ recover()
+ }()
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client.Call("Arith.Add", args, reply)
+}
+
+func dialDirect() (*Client, os.Error) {
+ return Dial("tcp", serverAddr)
+}
+
+func dialHTTP() (*Client, os.Error) {
+ return DialHTTP("tcp", httpServerAddr)
+}
+
+func countMallocs(dial func() (*Client, os.Error), t *testing.T) uint64 {
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ t.Fatal("error dialing", err)
+ }
+ args := &Args{7, 8}
+ reply := new(Reply)
+ runtime.UpdateMemStats()
+ mallocs := 0 - runtime.MemStats.Mallocs
+ const count = 100
+ for i := 0; i < count; i++ {
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+ }
+ runtime.UpdateMemStats()
+ mallocs += runtime.MemStats.Mallocs
+ return mallocs / count
+}
+
+func TestCountMallocs(t *testing.T) {
+ fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t))
+}
+
+func TestCountMallocsOverHTTP(t *testing.T) {
+ fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t))
+}
+
+func benchmarkEndToEnd(dial func() (*Client, os.Error), b *testing.B) {
+ b.StopTimer()
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ fmt.Println("error dialing", err)
+ return
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ fmt.Printf("Add: expected no error but got string %q", err.String())
+ break
+ }
+ if reply.C != args.A+args.B {
+ fmt.Printf("Add: expected %d got %d", reply.C, args.A+args.B)
+ break
+ }
+ }
+}
+
+func BenchmarkEndToEnd(b *testing.B) {
+ benchmarkEndToEnd(dialDirect, b)
+}
+
+func BenchmarkEndToEndHTTP(b *testing.B) {
+ benchmarkEndToEnd(dialHTTP, b)
+}