diff options
Diffstat (limited to 'src/pkg/rpc/server.go')
| -rw-r--r-- | src/pkg/rpc/server.go | 45 |
1 files changed, 34 insertions, 11 deletions
diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index 4c957597b..7df89a8d7 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -158,6 +158,12 @@ type Response struct { Error string // error, if any. } +// ClientInfo records information about an RPC client connection. +type ClientInfo struct { + LocalAddr string + RemoteAddr string +} + type serverType struct { sync.Mutex // protects the serviceMap serviceMap map[string]*service @@ -208,7 +214,7 @@ func (server *serverType) register(rcvr interface{}) os.Error { } // Method needs three ins: receiver, *args, *reply. // The args and reply must be structs until gobs are more general. - if mtype.NumIn() != 3 { + if mtype.NumIn() != 3 && mtype.NumIn() != 4 { log.Stderr("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } @@ -238,6 +244,13 @@ func (server *serverType) register(rcvr interface{}) os.Error { log.Stderr(mname, "reply type not public:", replyType) continue } + if mtype.NumIn() == 4 { + t := mtype.In(3) + if t != reflect.Typeof((*ClientInfo)(nil)) { + log.Stderr(mname, "last argument not *ClientInfo") + continue + } + } // Method needs one out: os.Error. if mtype.NumOut() != 1 { log.Stderr("method", mname, "has wrong number of outs:", mtype.NumOut()) @@ -288,13 +301,19 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se sending.Unlock() } -func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { +func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec, ci *ClientInfo) { mtype.Lock() mtype.numCalls++ mtype.Unlock() function := mtype.method.Func // Invoke the method, providing a new value for the reply. - returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}) + var args []reflect.Value + if mtype.method.Type.NumIn() == 3 { + args = []reflect.Value{s.rcvr, argv, replyv} + } else { + args = []reflect.Value{s.rcvr, argv, replyv, reflect.NewValue(ci)} + } + returnValues := function.Call(args) // The return value for the method is an os.Error. errInter := returnValues[0].Interface() errmsg := "" @@ -329,7 +348,7 @@ func (c *gobServerCodec) Close() os.Error { return c.rwc.Close() } -func (server *serverType) input(codec ServerCodec) { +func (server *serverType) input(codec ServerCodec, ci *ClientInfo) { sending := new(sync.Mutex) for { // Grab the request header. @@ -376,7 +395,7 @@ func (server *serverType) input(codec ServerCodec) { sendResponse(sending, req, replyv.Interface(), codec, err.String()) break } - go service.call(sending, mtype, req, argv, replyv, codec) + go service.call(sending, mtype, req, argv, replyv, codec, ci) } codec.Close() } @@ -387,7 +406,7 @@ func (server *serverType) accept(lis net.Listener) { if err != nil { log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? } - go ServeConn(conn) + go ServeConn(conn, &ClientInfo{conn.LocalAddr().String(), conn.RemoteAddr().String()}) } } @@ -419,14 +438,14 @@ type ServerCodec interface { // The caller typically invokes ServeConn in a go statement. // ServeConn uses the gob wire format (see package gob) on the // connection. To use an alternate codec, use ServeCodec. -func ServeConn(conn io.ReadWriteCloser) { - ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +func ServeConn(conn io.ReadWriteCloser, ci *ClientInfo) { + ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}, ci) } // ServeCodec is like ServeConn but uses the specified codec to // decode requests and encode responses. -func ServeCodec(codec ServerCodec) { - server.input(codec) +func ServeCodec(codec ServerCodec, ci *ClientInfo) { + server.input(codec, ci) } // Accept accepts connections on the listener and serves requests @@ -452,7 +471,11 @@ func serveHTTP(c *http.Conn, req *http.Request) { return } io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") - ServeConn(conn) + ci := &ClientInfo{ + LocalAddr: conn.(net.Conn).LocalAddr().String(), + RemoteAddr: c.RemoteAddr, + } + ServeConn(conn, ci) } // HandleHTTP registers an HTTP handler for RPC messages. |
