diff options
Diffstat (limited to 'src/pkg/rpc/server.go')
-rw-r--r-- | src/pkg/rpc/server.go | 61 |
1 files changed, 37 insertions, 24 deletions
diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index af31a65cc..acadeec37 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The rpc package provides access to the exported methods of an object across a + Package rpc provides access to the exported methods of an object across a network or other I/O connection. A server registers an object, making it visible as a service with the name of the type of the object. After registration, exported methods of the object will be accessible remotely. A server may register multiple @@ -13,8 +13,11 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are exported, that is, begin with an upper case letter. - - the method has two arguments, both pointers to exported types. + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -133,7 +136,7 @@ const ( // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error -var typeOfOsError = reflect.Typeof(unusedError).Elem() +var typeOfOsError = reflect.TypeOf(unusedError).Elem() type methodType struct { sync.Mutex // protects counters @@ -193,6 +196,14 @@ func isExported(name string) bool { return unicode.IsUpper(rune) } +// Is this type exported or local to this package? +func isExportedOrLocalType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() == "" || isExported(t.Name()) +} + // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method @@ -219,8 +230,8 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E server.serviceMap = make(map[string]*service) } s := new(service) - s.typ = reflect.Typeof(rcvr) - s.rcvr = reflect.NewValue(rcvr) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() if useName { sname = name @@ -252,23 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } + // First arg need not be a pointer. argType := mtype.In(1) - ok := argType.Kind() == reflect.Ptr - if !ok { - log.Println(mname, "arg type not a pointer:", mtype.In(1)) + if !isExportedOrLocalType(argType) { + log.Println(mname, "argument type not exported or local:", argType) continue } + // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println(mname, "reply type not a pointer:", mtype.In(2)) - continue - } - if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { - log.Println(mname, "argument type not exported:", argType) + log.Println("method", mname, "reply type not a pointer:", replyType) continue } - if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { - log.Println(mname, "reply type not exported:", replyType) + if !isExportedOrLocalType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) continue } // Method needs one out: os.Error. @@ -297,12 +305,6 @@ type InvalidRequest struct{} var invalidRequest = InvalidRequest{} -func _new(t reflect.Type) reflect.Value { - v := reflect.Zero(t) - v.Set(reflect.Zero(t.Elem()).Addr()) - return v -} - func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { resp := server.getResponse() // Encode the response header @@ -411,8 +413,16 @@ func (server *Server) ServeCodec(codec ServerCodec) { } // Decode the argument value. - argv := _new(mtype.ArgType) - replyv := _new(mtype.ReplyType) + var argv reflect.Value + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + replyv := reflect.New(mtype.ReplyType.Elem()) err = codec.ReadRequestBody(argv.Interface()) if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { @@ -424,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { server.sendResponse(sending, req, replyv.Interface(), codec, err.String()) continue } + if argIsValue { + argv = argv.Elem() + } go service.call(server, sending, mtype, req, argv, replyv, codec) } codec.Close() |