diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-01-17 12:40:45 +0100 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-01-17 12:40:45 +0100 |
commit | 3e45412327a2654a77944249962b3652e6142299 (patch) | |
tree | bc3bf69452afa055423cbe0c5cfa8ca357df6ccf /src/pkg/netchan | |
parent | c533680039762cacbc37db8dc7eed074c3e497be (diff) | |
download | golang-upstream/2011.01.12.tar.gz |
Imported Upstream version 2011.01.12upstream/2011.01.12
Diffstat (limited to 'src/pkg/netchan')
-rw-r--r-- | src/pkg/netchan/Makefile | 2 | ||||
-rw-r--r-- | src/pkg/netchan/common.go | 121 | ||||
-rw-r--r-- | src/pkg/netchan/export.go | 238 | ||||
-rw-r--r-- | src/pkg/netchan/import.go | 128 | ||||
-rw-r--r-- | src/pkg/netchan/netchan_test.go | 329 |
5 files changed, 688 insertions, 130 deletions
diff --git a/src/pkg/netchan/Makefile b/src/pkg/netchan/Makefile index a8a5c6a3c..9b9fdcf59 100644 --- a/src/pkg/netchan/Makefile +++ b/src/pkg/netchan/Makefile @@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. -include ../../Make.$(GOARCH) +include ../../Make.inc TARG=netchan GOFILES=\ diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go index 624397ef4..bde3087a5 100644 --- a/src/pkg/netchan/common.go +++ b/src/pkg/netchan/common.go @@ -10,6 +10,7 @@ import ( "os" "reflect" "sync" + "time" ) // The direction of a connection from the client's perspective. @@ -20,31 +21,65 @@ const ( Send ) +func (dir Dir) String() string { + switch dir { + case Recv: + return "Recv" + case Send: + return "Send" + } + return "???" +} + // Payload types const ( payRequest = iota // request structure follows payError // error structure follows payData // user payload follows + payAck // acknowledgement; no payload + payClosed // channel is now closed ) // A header is sent as a prefix to every transmission. It will be followed by // a request structure, an error structure, or an arbitrary user payload structure. type header struct { - name string - payloadType int + Name string + PayloadType int + SeqNum int64 } // Sent with a header once per channel from importer to exporter to report // that it wants to bind to a channel with the specified direction for count -// messages. If count is zero, it means unlimited. +// messages. If count is -1, it means unlimited. type request struct { - count int - dir Dir + Count int64 + Dir Dir } // Sent with a header to report an error. type error struct { - error string + Error string +} + +// Used to unify management of acknowledgements for import and export. +type unackedCounter interface { + unackedCount() int64 + ack() int64 + seq() int64 +} + +// A channel and its direction. +type chanDir struct { + ch *reflect.ChanValue + dir Dir +} + +// clientSet contains the objects and methods needed for tracking +// clients of an exporter and draining outstanding messages. +type clientSet struct { + mu sync.Mutex // protects access to channel and client maps + chans map[string]*chanDir + clients map[unackedCounter]bool } // Mutex-protected encoder and decoder pair. @@ -76,13 +111,81 @@ func (ed *encDec) decode(value reflect.Value) os.Error { // Encode a header and payload onto the connection. func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error { ed.encLock.Lock() - hdr.payloadType = payloadType + hdr.PayloadType = payloadType err := ed.enc.Encode(hdr) if err == nil { - err = ed.enc.Encode(payload) - } else { + if payload != nil { + err = ed.enc.Encode(payload) + } + } + if err != nil { // TODO: tear down connection if there is an error? } ed.encLock.Unlock() return err } + +// See the comment for Exporter.Drain. +func (cs *clientSet) drain(timeout int64) os.Error { + startTime := time.Nanoseconds() + for { + pending := false + cs.mu.Lock() + // Any messages waiting for a client? + for _, chDir := range cs.chans { + if chDir.ch.Len() > 0 { + pending = true + } + } + // Any unacknowledged messages? + for client := range cs.clients { + n := client.unackedCount() + if n > 0 { // Check for > rather than != just to be safe. + pending = true + break + } + } + cs.mu.Unlock() + if !pending { + break + } + if timeout > 0 && time.Nanoseconds()-startTime >= timeout { + return os.ErrorString("timeout") + } + time.Sleep(100 * 1e6) // 100 milliseconds + } + return nil +} + +// See the comment for Exporter.Sync. +func (cs *clientSet) sync(timeout int64) os.Error { + startTime := time.Nanoseconds() + // seq remembers the clients and their seqNum at point of entry. + seq := make(map[unackedCounter]int64) + for client := range cs.clients { + seq[client] = client.seq() + } + for { + pending := false + cs.mu.Lock() + // Any unacknowledged messages? Look only at clients that existed + // when we started and are still in this client set. + for client := range seq { + if _, ok := cs.clients[client]; ok { + if client.ack() < seq[client] { + pending = true + break + } + } + } + cs.mu.Unlock() + if !pending { + break + } + if timeout > 0 && time.Nanoseconds()-startTime >= timeout { + return os.ErrorString("timeout") + } + time.Sleep(100 * 1e6) // 100 milliseconds + } + return nil +} diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index a16714ba2..9ad388c18 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -19,7 +19,7 @@ */ package netchan -// BUG: can't use range clause to receive when using ImportNValues with N non-zero. +// BUG: can't use range clause to receive when using ImportNValues to limit the count. import ( "log" @@ -31,73 +31,69 @@ import ( // Export -// A channel and its associated information: a direction plus -// a handy marshaling place for its data. -type exportChan struct { - ch *reflect.ChanValue - dir Dir +// expLog is a logging convenience function. The first argument must be a string. +func expLog(args ...interface{}) { + args[0] = "netchan export: " + args[0].(string) + log.Print(args...) } // An Exporter allows a set of channels to be published on a single // network port. A single machine may have multiple Exporters // but they must use different ports. type Exporter struct { + *clientSet listener net.Listener - chanLock sync.Mutex // protects access to channel map - chans map[string]*exportChan } type expClient struct { *encDec - exp *Exporter + exp *Exporter + mu sync.Mutex // protects remaining fields + errored bool // client has been sent an error + seqNum int64 // sequences messages sent to client; has value of highest sent + ackNum int64 // highest sequence number acknowledged + seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu } func newClient(exp *Exporter, conn net.Conn) *expClient { client := new(expClient) client.exp = exp client.encDec = newEncDec(conn) + client.seqNum = 0 + client.ackNum = 0 return client } -// Wait for incoming connections, start a new runner for each -func (exp *Exporter) listen() { - for { - conn, err := exp.listener.Accept() - if err != nil { - log.Stderr("exporter.listen:", err) - break - } - client := newClient(exp, conn) - go client.run() - } -} - func (client *expClient) sendError(hdr *header, err string) { error := &error{err} - log.Stderr("export:", error.error) + expLog("sending error to client:", error.Error) client.encode(hdr, payError, error) // ignore any encode error, hope client gets it + client.mu.Lock() + client.errored = true + client.mu.Unlock() } -func (client *expClient) getChan(hdr *header, dir Dir) *exportChan { +func (client *expClient) getChan(hdr *header, dir Dir) *chanDir { exp := client.exp - exp.chanLock.Lock() - ech, ok := exp.chans[hdr.name] - exp.chanLock.Unlock() + exp.mu.Lock() + ech, ok := exp.chans[hdr.Name] + exp.mu.Unlock() if !ok { - client.sendError(hdr, "no such channel: "+hdr.name) + client.sendError(hdr, "no such channel: "+hdr.Name) return nil } if ech.dir != dir { - client.sendError(hdr, "wrong direction for channel: "+hdr.name) + client.sendError(hdr, "wrong direction for channel: "+hdr.Name) return nil } return ech } -// Manage sends and receives for a single client. For each (client Recv) request, -// this will launch a serveRecv goroutine to deliver the data for that channel, -// while (client Send) requests are handled as data arrives from the client. +// The function run manages sends and receives for a single client. For each +// (client Recv) request, this will launch a serveRecv goroutine to deliver +// the data for that channel, while (client Send) requests are handled as +// data arrives from the client. func (client *expClient) run() { hdr := new(header) hdrValue := reflect.NewValue(hdr) @@ -105,40 +101,58 @@ func (client *expClient) run() { reqValue := reflect.NewValue(req) error := new(error) for { + *hdr = header{} if err := client.decode(hdrValue); err != nil { - log.Stderr("error decoding client header:", err) - // TODO: tear down connection - return + expLog("error decoding client header:", err) + break } - switch hdr.payloadType { + switch hdr.PayloadType { case payRequest: + *req = request{} if err := client.decode(reqValue); err != nil { - log.Stderr("error decoding client request:", err) - // TODO: tear down connection - return + expLog("error decoding client request:", err) + break } - switch req.dir { + switch req.Dir { case Recv: - go client.serveRecv(*hdr, req.count) + go client.serveRecv(*hdr, req.Count) case Send: // Request to send is clear as a matter of protocol // but not actually used by the implementation. // The actual sends will have payload type payData. // TODO: manage the count? default: - error.error = "export request: can't handle channel direction" - log.Stderr(error.error, req.dir) + error.Error = "request: can't handle channel direction" + expLog(error.Error, req.Dir) client.encode(hdr, payError, error) } case payData: client.serveSend(*hdr) + case payClosed: + client.serveClosed(*hdr) + case payAck: + client.mu.Lock() + if client.ackNum != hdr.SeqNum-1 { + // Since the sequence number is incremented and the message is sent + // in a single instance of locking client.mu, the messages are guaranteed + // to be sent in order. Therefore receipt of acknowledgement N means + // all messages <=N have been seen by the recipient. We check anyway. + expLog("sequence out of order:", client.ackNum, hdr.SeqNum) + } + if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count. + client.ackNum = hdr.SeqNum + } + client.mu.Unlock() + default: + log.Exit("netchan export: unknown payload type", hdr.PayloadType) } } + client.exp.delClient(client) } // Send all the data on a single channel to a client asking for a Recv. // The header is passed by value to avoid issues of overwriting. -func (client *expClient) serveRecv(hdr header, count int) { +func (client *expClient) serveRecv(hdr header, count int64) { ech := client.getChan(&hdr, Send) if ech == nil { return @@ -146,16 +160,30 @@ func (client *expClient) serveRecv(hdr header, count int) { for { val := ech.ch.Recv() if ech.ch.Closed() { - client.sendError(&hdr, os.EOF.String()) + if err := client.encode(&hdr, payClosed, nil); err != nil { + expLog("error encoding server closed message:", err) + } break } - if err := client.encode(&hdr, payData, val.Interface()); err != nil { - log.Stderr("error encoding client response:", err) + // We hold the lock during transmission to guarantee messages are + // sent in sequence number order. Also, we increment first so the + // value of client.seqNum is the value of the highest used sequence + // number, not one beyond. + client.mu.Lock() + client.seqNum++ + hdr.SeqNum = client.seqNum + client.seqLock.Lock() // guarantee ordering of messages + client.mu.Unlock() + err := client.encode(&hdr, payData, val.Interface()) + client.seqLock.Unlock() + if err != nil { + expLog("error encoding client response:", err) client.sendError(&hdr, err.String()) break } - if count > 0 { - if count--; count == 0 { + // Negative count means run forever. + if count >= 0 { + if count--; count <= 0 { break } } @@ -172,11 +200,54 @@ func (client *expClient) serveSend(hdr header) { // Create a new value for each received item. val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem()) if err := client.decode(val); err != nil { - log.Stderr("exporter value decode:", err) + expLog("value decode:", err) return } ech.ch.Send(val) - // TODO count +} + +// Report that client has closed the channel that is sending to us. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveClosed(hdr header) { + ech := client.getChan(&hdr, Recv) + if ech == nil { + return + } + ech.ch.Close() +} + +func (client *expClient) unackedCount() int64 { + client.mu.Lock() + n := client.seqNum - client.ackNum + client.mu.Unlock() + return n +} + +func (client *expClient) seq() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +func (client *expClient) ack() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +// Wait for incoming connections, start a new runner for each +func (exp *Exporter) listen() { + for { + conn, err := exp.listener.Accept() + if err != nil { + expLog("listen:", err) + break + } + client := exp.addClient(conn) + go client.run() + } } // NewExporter creates a new Exporter to export channels @@ -188,12 +259,53 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) { } e := &Exporter{ listener: listener, - chans: make(map[string]*exportChan), + clientSet: &clientSet{ + chans: make(map[string]*chanDir), + clients: make(map[unackedCounter]bool), + }, } go e.listen() return e, nil } +// addClient creates a new expClient and records its existence +func (exp *Exporter) addClient(conn net.Conn) *expClient { + client := newClient(exp, conn) + exp.mu.Lock() + exp.clients[client] = true + exp.mu.Unlock() + return client +} + +// delClient forgets the client existed +func (exp *Exporter) delClient(client *expClient) { + exp.mu.Lock() + exp.clients[client] = false, false + exp.mu.Unlock() +} + +// Drain waits until all messages sent from this exporter/importer, including +// those not yet sent to any client and possibly including those sent while +// Drain was executing, have been received by the importer. In short, it +// waits until all the exporter's messages have been received by a client. +// If the timeout (measured in nanoseconds) is positive and Drain takes +// longer than that to complete, an error is returned. +func (exp *Exporter) Drain(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.drain(timeout) +} + +// Sync waits until all clients of the exporter have received the messages +// that were sent at the time Sync was invoked. Unlike Drain, it does not +// wait for messages sent while it is running or messages that have not been +// dispatched to any client. If the timeout (measured in nanoseconds) is +// positive and Sync takes longer than that to complete, an error is +// returned. +func (exp *Exporter) Sync(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.sync(timeout) +} + // Addr returns the Exporter's local network address. func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() } @@ -229,12 +341,28 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { if err != nil { return err } - exp.chanLock.Lock() - defer exp.chanLock.Unlock() + exp.mu.Lock() + defer exp.mu.Unlock() _, present := exp.chans[name] if present { return os.ErrorString("channel name already being exported:" + name) } - exp.chans[name] = &exportChan{ch, dir} + exp.chans[name] = &chanDir{ch, dir} + return nil +} + +// Hangup disassociates the named channel from the Exporter and closes +// the channel. Messages in flight for the channel may be dropped. +func (exp *Exporter) Hangup(name string) os.Error { + exp.mu.Lock() + chDir, ok := exp.chans[name] + if ok { + exp.chans[name] = nil, false + } + exp.mu.Unlock() + if !ok { + return os.ErrorString("netchan export: hangup: no such channel: " + name) + } + chDir.ch.Close() return nil } diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index 244a83c5b..baae367a0 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -14,11 +14,10 @@ import ( // Import -// A channel and its associated information: a template value and direction, -// plus a handy marshaling place for its data. -type importChan struct { - ch *reflect.ChanValue - dir Dir +// impLog is a logging convenience function. The first argument must be a string. +func impLog(args ...interface{}) { + args[0] = "netchan import: " + args[0].(string) + log.Print(args...) } // An Importer allows a set of channels to be imported from a single @@ -28,7 +27,8 @@ type Importer struct { *encDec conn net.Conn chanLock sync.Mutex // protects access to channel map - chans map[string]*importChan + chans map[string]*chanDir + errors chan os.Error } // NewImporter creates a new Importer object to import channels @@ -43,7 +43,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { imp := new(Importer) imp.encDec = newEncDec(conn) imp.conn = conn - imp.chans = make(map[string]*importChan) + imp.chans = make(map[string]*chanDir) + imp.errors = make(chan os.Error, 10) go imp.run() return imp, nil } @@ -67,61 +68,92 @@ func (imp *Importer) run() { // Loop on responses; requests are sent by ImportNValues() hdr := new(header) hdrValue := reflect.NewValue(hdr) + ackHdr := new(header) err := new(error) errValue := reflect.NewValue(err) for { + *hdr = header{} if e := imp.decode(hdrValue); e != nil { - log.Stderr("importer header:", e) + impLog("header:", e) imp.shutdown() return } - switch hdr.payloadType { + switch hdr.PayloadType { case payData: // done lower in loop case payError: if e := imp.decode(errValue); e != nil { - log.Stderr("importer error:", e) + impLog("error:", e) return } - if err.error != "" { - log.Stderr("importer response error:", err.error) - imp.shutdown() - return + if err.Error != "" { + impLog("response error:", err.Error) + if sent := imp.errors <- os.ErrorString(err.Error); !sent { + imp.shutdown() + return + } + continue // errors are not acknowledged. } + case payClosed: + ich := imp.getChan(hdr.Name) + if ich != nil { + ich.ch.Close() + } + continue // closes are not acknowledged. default: - log.Stderr("unexpected payload type:", hdr.payloadType) + impLog("unexpected payload type:", hdr.PayloadType) return } - imp.chanLock.Lock() - ich, ok := imp.chans[hdr.name] - imp.chanLock.Unlock() - if !ok { - log.Stderr("unknown name in request:", hdr.name) - return + ich := imp.getChan(hdr.Name) + if ich == nil { + continue } if ich.dir != Recv { - log.Stderr("cannot happen: receive from non-Recv channel") + impLog("cannot happen: receive from non-Recv channel") return } + // Acknowledge receipt + ackHdr.Name = hdr.Name + ackHdr.SeqNum = hdr.SeqNum + imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) if e := imp.decode(value); e != nil { - log.Stderr("importer value decode:", e) + impLog("importer value decode:", e) return } ich.ch.Send(value) } } +func (imp *Importer) getChan(name string) *chanDir { + imp.chanLock.Lock() + ich := imp.chans[name] + imp.chanLock.Unlock() + if ich == nil { + impLog("unknown name in netchan request:", name) + return nil + } + return ich +} + +// Errors returns a channel from which transmission and protocol errors +// can be read. Clients of the importer are not required to read the error +// channel for correct execution. However, if too many errors occur +// without being read from the error channel, the importer will shut down. +func (imp *Importer) Errors() chan os.Error { + return imp.errors +} + // Import imports a channel of the given type and specified direction. -// It is equivalent to ImportNValues with a count of 0, meaning unbounded. +// It is equivalent to ImportNValues with a count of -1, meaning unbounded. func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { - return imp.ImportNValues(name, chT, dir, 0) + return imp.ImportNValues(name, chT, dir, -1) } // ImportNValues imports a channel of the given type and specified direction // and then receives or transmits up to n values on that channel. A value of -// n==0 implies an unbounded number of values. The channel to be bound to +// n==-1 implies an unbounded number of values. The channel to be bound to // the remote site's channel is provided in the call and may be of arbitrary // channel type. // Despite the literal signature, the effective signature is @@ -130,7 +162,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { // imp, err := NewImporter("tcp", "netchanserver.mydomain.com:1234") // if err != nil { log.Exit(err) } // ch := make(chan myType) -// err := imp.ImportNValues("name", ch, Recv, 1) +// err = imp.ImportNValues("name", ch, Recv, 1) // if err != nil { log.Exit(err) } // fmt.Printf("%+v\n", <-ch) func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) os.Error { @@ -144,24 +176,26 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) if present { return os.ErrorString("channel name already being imported:" + name) } - imp.chans[name] = &importChan{ch, dir} + imp.chans[name] = &chanDir{ch, dir} // Tell the other side about this channel. - hdr := new(header) - hdr.name = name - hdr.payloadType = payRequest - req := new(request) - req.dir = dir - req.count = n - if err := imp.encode(hdr, payRequest, req); err != nil { - log.Stderr("importer request encode:", err) + hdr := &header{Name: name} + req := &request{Count: int64(n), Dir: dir} + if err = imp.encode(hdr, payRequest, req); err != nil { + impLog("request encode:", err) return err } if dir == Send { go func() { - for i := 0; n == 0 || i < n; i++ { + for i := 0; n == -1 || i < n; i++ { val := ch.Recv() - if err := imp.encode(hdr, payData, val.Interface()); err != nil { - log.Stderr("error encoding client response:", err) + if ch.Closed() { + if err = imp.encode(hdr, payClosed, nil); err != nil { + impLog("error encoding client closed message:", err) + } + return + } + if err = imp.encode(hdr, payData, val.Interface()); err != nil { + impLog("error encoding client send:", err) return } } @@ -169,3 +203,19 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) } return nil } + +// Hangup disassociates the named channel from the Importer and closes +// the channel. Messages in flight for the channel may be dropped. +func (imp *Importer) Hangup(name string) os.Error { + imp.chanLock.Lock() + chDir, ok := imp.chans[name] + if ok { + imp.chans[name] = nil, false + } + imp.chanLock.Unlock() + if !ok { + return os.ErrorString("netchan import: hangup: no such channel: " + name) + } + chDir.ch.Close() + return nil +} diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go index 6b5c67c3c..766c4c474 100644 --- a/src/pkg/netchan/netchan_test.go +++ b/src/pkg/netchan/netchan_test.go @@ -4,38 +4,67 @@ package netchan -import "testing" +import ( + "strings" + "testing" + "time" +) const count = 10 // number of items in most tests const closeCount = 5 // number of items when sender closes early +const base = 23 + func exportSend(exp *Exporter, n int, t *testing.T) { ch := make(chan int) err := exp.Export("exportedSend", ch, Send) if err != nil { t.Fatal("exportSend:", err) } - for i := 0; i < n; i++ { - ch <- 23+i - } - close(ch) + go func() { + for i := 0; i < n; i++ { + ch <- base+i + } + close(ch) + }() } -func exportReceive(exp *Exporter, t *testing.T) { +func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) { ch := make(chan int) err := exp.Export("exportedRecv", ch, Recv) + expDone <- true if err != nil { t.Fatal("exportReceive:", err) } for i := 0; i < count; i++ { v := <-ch - if v != 45+i { - t.Errorf("export Receive: bad value: expected 4%d; got %d", 45+i, v) + if closed(ch) { + if i != closeCount { + t.Errorf("exportReceive expected close at %d; got one at %d", closeCount, i) + } + break + } + if v != base+i { + t.Errorf("export Receive: bad value: expected %d+%d=%d; got %d", base, i, base+i, v) } } } -func importReceive(imp *Importer, t *testing.T) { +func importSend(imp *Importer, n int, t *testing.T) { + ch := make(chan int) + err := imp.ImportNValues("exportedRecv", ch, Send, count) + if err != nil { + t.Fatal("importSend:", err) + } + go func() { + for i := 0; i < n; i++ { + ch <- base+i + } + close(ch) + }() +} + +func importReceive(imp *Importer, t *testing.T, done chan bool) { ch := make(chan int) err := imp.ImportNValues("exportedSend", ch, Recv, count) if err != nil { @@ -45,28 +74,33 @@ func importReceive(imp *Importer, t *testing.T) { v := <-ch if closed(ch) { if i != closeCount { - t.Errorf("expected close at %d; got one at %d\n", closeCount, i) + t.Errorf("importReceive expected close at %d; got one at %d", closeCount, i) } break } if v != 23+i { - t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v) + t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v) } } + if done != nil { + done <- true + } } -func importSend(imp *Importer, t *testing.T) { - ch := make(chan int) - err := imp.ImportNValues("exportedRecv", ch, Send, count) +func TestExportSendImportReceive(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") if err != nil { - t.Fatal("importSend:", err) + t.Fatal("new exporter:", err) } - for i := 0; i < count; i++ { - ch <- 45+i + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) } + exportSend(exp, count, t) + importReceive(imp, t, nil) } -func TestExportSendImportReceive(t *testing.T) { +func TestExportReceiveImportSend(t *testing.T) { exp, err := NewExporter("tcp", "127.0.0.1:0") if err != nil { t.Fatal("new exporter:", err) @@ -75,11 +109,18 @@ func TestExportSendImportReceive(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - go exportSend(exp, count, t) - importReceive(imp, t) + expDone := make(chan bool) + done := make(chan bool) + go func() { + exportReceive(exp, t, expDone) + done <- true + }() + <-expDone + importSend(imp, count, t) + <-done } -func TestExportReceiveImportSend(t *testing.T) { +func TestClosingExportSendImportReceive(t *testing.T) { exp, err := NewExporter("tcp", "127.0.0.1:0") if err != nil { t.Fatal("new exporter:", err) @@ -88,11 +129,92 @@ func TestExportReceiveImportSend(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - go importSend(imp, t) - exportReceive(exp, t) + exportSend(exp, closeCount, t) + importReceive(imp, t, nil) } -func TestClosingExportSendImportReceive(t *testing.T) { +func TestClosingImportSendExportReceive(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + expDone := make(chan bool) + done := make(chan bool) + go func() { + exportReceive(exp, t, expDone) + done <- true + }() + <-expDone + importSend(imp, closeCount, t) + <-done +} + +func TestErrorForIllegalChannel(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + // Now export a channel. + ch := make(chan int, 1) + err = exp.Export("aChannel", ch, Send) + if err != nil { + t.Fatal("export:", err) + } + ch <- 1234 + close(ch) + // Now try to import a different channel. + ch = make(chan int) + err = imp.Import("notAChannel", ch, Recv) + if err != nil { + t.Fatal("import:", err) + } + // Expect an error now. Start a timeout. + timeout := make(chan bool, 1) // buffered so closure will not hang around. + go func() { + time.Sleep(10e9) // very long, to give even really slow machines a chance. + timeout <- true + }() + select { + case err = <-imp.Errors(): + if strings.Index(err.String(), "no such channel") < 0 { + t.Error("wrong error for nonexistent channel:", err) + } + case <-timeout: + t.Error("import of nonexistent channel did not receive an error") + } +} + +// Not a great test but it does at least invoke Drain. +func TestExportDrain(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + done := make(chan bool) + go func() { + exportSend(exp, closeCount, t) + done <- true + }() + <-done + go importReceive(imp, t, done) + exp.Drain(0) + <-done +} + +// Not a great test but it does at least invoke Sync. +func TestExportSync(t *testing.T) { exp, err := NewExporter("tcp", "127.0.0.1:0") if err != nil { t.Fatal("new exporter:", err) @@ -101,6 +223,161 @@ func TestClosingExportSendImportReceive(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - go exportSend(exp, closeCount, t) - importReceive(imp, t) + done := make(chan bool) + exportSend(exp, closeCount, t) + go importReceive(imp, t, done) + exp.Sync(0) + <-done +} + +// Test hanging up the send side of an export. +// TODO: test hanging up the receive side of an export. +func TestExportHangup(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + ech := make(chan int) + err = exp.Export("exportedSend", ech, Send) + if err != nil { + t.Fatal("export:", err) + } + // Prepare to receive two values. We'll actually deliver only one. + ich := make(chan int) + err = imp.ImportNValues("exportedSend", ich, Recv, 2) + if err != nil { + t.Fatal("import exportedSend:", err) + } + // Send one value, receive it. + const Value = 1234 + ech <- Value + v := <-ich + if v != Value { + t.Fatal("expected", Value, "got", v) + } + // Now hang up the channel. Importer should see it close. + exp.Hangup("exportedSend") + v = <-ich + if !closed(ich) { + t.Fatal("expected channel to be closed; got value", v) + } +} + +// Test hanging up the send side of an import. +// TODO: test hanging up the receive side of an import. +func TestImportHangup(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + ech := make(chan int) + err = exp.Export("exportedRecv", ech, Recv) + if err != nil { + t.Fatal("export:", err) + } + // Prepare to Send two values. We'll actually deliver only one. + ich := make(chan int) + err = imp.ImportNValues("exportedRecv", ich, Send, 2) + if err != nil { + t.Fatal("import exportedRecv:", err) + } + // Send one value, receive it. + const Value = 1234 + ich <- Value + v := <-ech + if v != Value { + t.Fatal("expected", Value, "got", v) + } + // Now hang up the channel. Exporter should see it close. + imp.Hangup("exportedRecv") + v = <-ech + if !closed(ech) { + t.Fatal("expected channel to be closed; got value", v) + } +} + +// This test cross-connects a pair of exporter/importer pairs. +type value struct { + i int + source string +} + +func TestCrossConnect(t *testing.T) { + e1, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + i1, err := NewImporter("tcp", e1.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + + e2, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + i2, err := NewImporter("tcp", e2.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + + go crossExport(e1, e2, t) + crossImport(i1, i2, t) +} + +// Export side of cross-traffic. +func crossExport(e1, e2 *Exporter, t *testing.T) { + s := make(chan value) + err := e1.Export("exportedSend", s, Send) + if err != nil { + t.Fatal("exportSend:", err) + } + + r := make(chan value) + err = e2.Export("exportedReceive", r, Recv) + if err != nil { + t.Fatal("exportReceive:", err) + } + + crossLoop("export", s, r, t) +} + +// Import side of cross-traffic. +func crossImport(i1, i2 *Importer, t *testing.T) { + s := make(chan value) + err := i2.Import("exportedReceive", s, Send) + if err != nil { + t.Fatal("import of exportedReceive:", err) + } + + r := make(chan value) + err = i1.Import("exportedSend", r, Recv) + if err != nil { + t.Fatal("import of exported Send:", err) + } + + crossLoop("import", s, r, t) +} + +// Cross-traffic: send and receive 'count' numbers. +func crossLoop(name string, s, r chan value, t *testing.T) { + for si, ri := 0, 0; si < count && ri < count; { + select { + case s <- value{si, name}: + si++ + case v := <-r: + if v.i != ri { + t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v) + } + ri++ + } + } } |