summaryrefslogtreecommitdiff
path: root/src/pkg/netchan/common.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/netchan/common.go')
-rw-r--r--src/pkg/netchan/common.go121
1 files changed, 112 insertions, 9 deletions
diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go
index 624397ef4..bde3087a5 100644
--- a/src/pkg/netchan/common.go
+++ b/src/pkg/netchan/common.go
@@ -10,6 +10,7 @@ import (
"os"
"reflect"
"sync"
+ "time"
)
// The direction of a connection from the client's perspective.
@@ -20,31 +21,65 @@ const (
Send
)
+func (dir Dir) String() string {
+ switch dir {
+ case Recv:
+ return "Recv"
+ case Send:
+ return "Send"
+ }
+ return "???"
+}
+
// Payload types
const (
payRequest = iota // request structure follows
payError // error structure follows
payData // user payload follows
+ payAck // acknowledgement; no payload
+ payClosed // channel is now closed
)
// A header is sent as a prefix to every transmission. It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure.
type header struct {
- name string
- payloadType int
+ Name string
+ PayloadType int
+ SeqNum int64
}
// Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count
-// messages. If count is zero, it means unlimited.
+// messages. If count is -1, it means unlimited.
type request struct {
- count int
- dir Dir
+ Count int64
+ Dir Dir
}
// Sent with a header to report an error.
type error struct {
- error string
+ Error string
+}
+
+// Used to unify management of acknowledgements for import and export.
+type unackedCounter interface {
+ unackedCount() int64
+ ack() int64
+ seq() int64
+}
+
+// A channel and its direction.
+type chanDir struct {
+ ch *reflect.ChanValue
+ dir Dir
+}
+
+// clientSet contains the objects and methods needed for tracking
+// clients of an exporter and draining outstanding messages.
+type clientSet struct {
+ mu sync.Mutex // protects access to channel and client maps
+ chans map[string]*chanDir
+ clients map[unackedCounter]bool
}
// Mutex-protected encoder and decoder pair.
@@ -76,13 +111,81 @@ func (ed *encDec) decode(value reflect.Value) os.Error {
// Encode a header and payload onto the connection.
func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error {
ed.encLock.Lock()
- hdr.payloadType = payloadType
+ hdr.PayloadType = payloadType
err := ed.enc.Encode(hdr)
if err == nil {
- err = ed.enc.Encode(payload)
- } else {
+ if payload != nil {
+ err = ed.enc.Encode(payload)
+ }
+ }
+ if err != nil {
// TODO: tear down connection if there is an error?
}
ed.encLock.Unlock()
return err
}
+
+// See the comment for Exporter.Drain.
+func (cs *clientSet) drain(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any messages waiting for a client?
+ for _, chDir := range cs.chans {
+ if chDir.ch.Len() > 0 {
+ pending = true
+ }
+ }
+ // Any unacknowledged messages?
+ for client := range cs.clients {
+ n := client.unackedCount()
+ if n > 0 { // Check for > rather than != just to be safe.
+ pending = true
+ break
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}
+
+// See the comment for Exporter.Sync.
+func (cs *clientSet) sync(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ // seq remembers the clients and their seqNum at point of entry.
+ seq := make(map[unackedCounter]int64)
+ for client := range cs.clients {
+ seq[client] = client.seq()
+ }
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any unacknowledged messages? Look only at clients that existed
+ // when we started and are still in this client set.
+ for client := range seq {
+ if _, ok := cs.clients[client]; ok {
+ if client.ack() < seq[client] {
+ pending = true
+ break
+ }
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}