summaryrefslogtreecommitdiff
path: root/src/pkg/netchan
diff options
context:
space:
mode:
authorOndřej Surý <ondrej@sury.org>2011-01-17 12:40:45 +0100
committerOndřej Surý <ondrej@sury.org>2011-01-17 12:40:45 +0100
commit3e45412327a2654a77944249962b3652e6142299 (patch)
treebc3bf69452afa055423cbe0c5cfa8ca357df6ccf /src/pkg/netchan
parentc533680039762cacbc37db8dc7eed074c3e497be (diff)
downloadgolang-upstream/2011.01.12.tar.gz
Imported Upstream version 2011.01.12upstream/2011.01.12
Diffstat (limited to 'src/pkg/netchan')
-rw-r--r--src/pkg/netchan/Makefile2
-rw-r--r--src/pkg/netchan/common.go121
-rw-r--r--src/pkg/netchan/export.go238
-rw-r--r--src/pkg/netchan/import.go128
-rw-r--r--src/pkg/netchan/netchan_test.go329
5 files changed, 688 insertions, 130 deletions
diff --git a/src/pkg/netchan/Makefile b/src/pkg/netchan/Makefile
index a8a5c6a3c..9b9fdcf59 100644
--- a/src/pkg/netchan/Makefile
+++ b/src/pkg/netchan/Makefile
@@ -2,7 +2,7 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
-include ../../Make.$(GOARCH)
+include ../../Make.inc
TARG=netchan
GOFILES=\
diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go
index 624397ef4..bde3087a5 100644
--- a/src/pkg/netchan/common.go
+++ b/src/pkg/netchan/common.go
@@ -10,6 +10,7 @@ import (
"os"
"reflect"
"sync"
+ "time"
)
// The direction of a connection from the client's perspective.
@@ -20,31 +21,65 @@ const (
Send
)
+func (dir Dir) String() string {
+ switch dir {
+ case Recv:
+ return "Recv"
+ case Send:
+ return "Send"
+ }
+ return "???"
+}
+
// Payload types
const (
payRequest = iota // request structure follows
payError // error structure follows
payData // user payload follows
+ payAck // acknowledgement; no payload
+ payClosed // channel is now closed
)
// A header is sent as a prefix to every transmission. It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure.
type header struct {
- name string
- payloadType int
+ Name string
+ PayloadType int
+ SeqNum int64
}
// Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count
-// messages. If count is zero, it means unlimited.
+// messages. If count is -1, it means unlimited.
type request struct {
- count int
- dir Dir
+ Count int64
+ Dir Dir
}
// Sent with a header to report an error.
type error struct {
- error string
+ Error string
+}
+
+// Used to unify management of acknowledgements for import and export.
+type unackedCounter interface {
+ unackedCount() int64
+ ack() int64
+ seq() int64
+}
+
+// A channel and its direction.
+type chanDir struct {
+ ch *reflect.ChanValue
+ dir Dir
+}
+
+// clientSet contains the objects and methods needed for tracking
+// clients of an exporter and draining outstanding messages.
+type clientSet struct {
+ mu sync.Mutex // protects access to channel and client maps
+ chans map[string]*chanDir
+ clients map[unackedCounter]bool
}
// Mutex-protected encoder and decoder pair.
@@ -76,13 +111,81 @@ func (ed *encDec) decode(value reflect.Value) os.Error {
// Encode a header and payload onto the connection.
func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error {
ed.encLock.Lock()
- hdr.payloadType = payloadType
+ hdr.PayloadType = payloadType
err := ed.enc.Encode(hdr)
if err == nil {
- err = ed.enc.Encode(payload)
- } else {
+ if payload != nil {
+ err = ed.enc.Encode(payload)
+ }
+ }
+ if err != nil {
// TODO: tear down connection if there is an error?
}
ed.encLock.Unlock()
return err
}
+
+// See the comment for Exporter.Drain.
+func (cs *clientSet) drain(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any messages waiting for a client?
+ for _, chDir := range cs.chans {
+ if chDir.ch.Len() > 0 {
+ pending = true
+ }
+ }
+ // Any unacknowledged messages?
+ for client := range cs.clients {
+ n := client.unackedCount()
+ if n > 0 { // Check for > rather than != just to be safe.
+ pending = true
+ break
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}
+
+// See the comment for Exporter.Sync.
+func (cs *clientSet) sync(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ // seq remembers the clients and their seqNum at point of entry.
+ seq := make(map[unackedCounter]int64)
+ for client := range cs.clients {
+ seq[client] = client.seq()
+ }
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any unacknowledged messages? Look only at clients that existed
+ // when we started and are still in this client set.
+ for client := range seq {
+ if _, ok := cs.clients[client]; ok {
+ if client.ack() < seq[client] {
+ pending = true
+ break
+ }
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}
diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go
index a16714ba2..9ad388c18 100644
--- a/src/pkg/netchan/export.go
+++ b/src/pkg/netchan/export.go
@@ -19,7 +19,7 @@
*/
package netchan
-// BUG: can't use range clause to receive when using ImportNValues with N non-zero.
+// BUG: can't use range clause to receive when using ImportNValues to limit the count.
import (
"log"
@@ -31,73 +31,69 @@ import (
// Export
-// A channel and its associated information: a direction plus
-// a handy marshaling place for its data.
-type exportChan struct {
- ch *reflect.ChanValue
- dir Dir
+// expLog is a logging convenience function. The first argument must be a string.
+func expLog(args ...interface{}) {
+ args[0] = "netchan export: " + args[0].(string)
+ log.Print(args...)
}
// An Exporter allows a set of channels to be published on a single
// network port. A single machine may have multiple Exporters
// but they must use different ports.
type Exporter struct {
+ *clientSet
listener net.Listener
- chanLock sync.Mutex // protects access to channel map
- chans map[string]*exportChan
}
type expClient struct {
*encDec
- exp *Exporter
+ exp *Exporter
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
}
func newClient(exp *Exporter, conn net.Conn) *expClient {
client := new(expClient)
client.exp = exp
client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
return client
}
-// Wait for incoming connections, start a new runner for each
-func (exp *Exporter) listen() {
- for {
- conn, err := exp.listener.Accept()
- if err != nil {
- log.Stderr("exporter.listen:", err)
- break
- }
- client := newClient(exp, conn)
- go client.run()
- }
-}
-
func (client *expClient) sendError(hdr *header, err string) {
error := &error{err}
- log.Stderr("export:", error.error)
+ expLog("sending error to client:", error.Error)
client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
}
-func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
+func (client *expClient) getChan(hdr *header, dir Dir) *chanDir {
exp := client.exp
- exp.chanLock.Lock()
- ech, ok := exp.chans[hdr.name]
- exp.chanLock.Unlock()
+ exp.mu.Lock()
+ ech, ok := exp.chans[hdr.Name]
+ exp.mu.Unlock()
if !ok {
- client.sendError(hdr, "no such channel: "+hdr.name)
+ client.sendError(hdr, "no such channel: "+hdr.Name)
return nil
}
if ech.dir != dir {
- client.sendError(hdr, "wrong direction for channel: "+hdr.name)
+ client.sendError(hdr, "wrong direction for channel: "+hdr.Name)
return nil
}
return ech
}
-// Manage sends and receives for a single client. For each (client Recv) request,
-// this will launch a serveRecv goroutine to deliver the data for that channel,
-// while (client Send) requests are handled as data arrives from the client.
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
func (client *expClient) run() {
hdr := new(header)
hdrValue := reflect.NewValue(hdr)
@@ -105,40 +101,58 @@ func (client *expClient) run() {
reqValue := reflect.NewValue(req)
error := new(error)
for {
+ *hdr = header{}
if err := client.decode(hdrValue); err != nil {
- log.Stderr("error decoding client header:", err)
- // TODO: tear down connection
- return
+ expLog("error decoding client header:", err)
+ break
}
- switch hdr.payloadType {
+ switch hdr.PayloadType {
case payRequest:
+ *req = request{}
if err := client.decode(reqValue); err != nil {
- log.Stderr("error decoding client request:", err)
- // TODO: tear down connection
- return
+ expLog("error decoding client request:", err)
+ break
}
- switch req.dir {
+ switch req.Dir {
case Recv:
- go client.serveRecv(*hdr, req.count)
+ go client.serveRecv(*hdr, req.Count)
case Send:
// Request to send is clear as a matter of protocol
// but not actually used by the implementation.
// The actual sends will have payload type payData.
// TODO: manage the count?
default:
- error.error = "export request: can't handle channel direction"
- log.Stderr(error.error, req.dir)
+ error.Error = "request: can't handle channel direction"
+ expLog(error.Error, req.Dir)
client.encode(hdr, payError, error)
}
case payData:
client.serveSend(*hdr)
+ case payClosed:
+ client.serveClosed(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.SeqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ expLog("sequence out of order:", client.ackNum, hdr.SeqNum)
+ }
+ if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.SeqNum
+ }
+ client.mu.Unlock()
+ default:
+ log.Exit("netchan export: unknown payload type", hdr.PayloadType)
}
}
+ client.exp.delClient(client)
}
// Send all the data on a single channel to a client asking for a Recv.
// The header is passed by value to avoid issues of overwriting.
-func (client *expClient) serveRecv(hdr header, count int) {
+func (client *expClient) serveRecv(hdr header, count int64) {
ech := client.getChan(&hdr, Send)
if ech == nil {
return
@@ -146,16 +160,30 @@ func (client *expClient) serveRecv(hdr header, count int) {
for {
val := ech.ch.Recv()
if ech.ch.Closed() {
- client.sendError(&hdr, os.EOF.String())
+ if err := client.encode(&hdr, payClosed, nil); err != nil {
+ expLog("error encoding server closed message:", err)
+ }
break
}
- if err := client.encode(&hdr, payData, val.Interface()); err != nil {
- log.Stderr("error encoding client response:", err)
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.seqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.SeqNum = client.seqNum
+ client.seqLock.Lock() // guarantee ordering of messages
+ client.mu.Unlock()
+ err := client.encode(&hdr, payData, val.Interface())
+ client.seqLock.Unlock()
+ if err != nil {
+ expLog("error encoding client response:", err)
client.sendError(&hdr, err.String())
break
}
- if count > 0 {
- if count--; count == 0 {
+ // Negative count means run forever.
+ if count >= 0 {
+ if count--; count <= 0 {
break
}
}
@@ -172,11 +200,54 @@ func (client *expClient) serveSend(hdr header) {
// Create a new value for each received item.
val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem())
if err := client.decode(val); err != nil {
- log.Stderr("exporter value decode:", err)
+ expLog("value decode:", err)
return
}
ech.ch.Send(val)
- // TODO count
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+ ech := client.getChan(&hdr, Recv)
+ if ech == nil {
+ return
+ }
+ ech.ch.Close()
+}
+
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+// Wait for incoming connections, start a new runner for each
+func (exp *Exporter) listen() {
+ for {
+ conn, err := exp.listener.Accept()
+ if err != nil {
+ expLog("listen:", err)
+ break
+ }
+ client := exp.addClient(conn)
+ go client.run()
+ }
}
// NewExporter creates a new Exporter to export channels
@@ -188,12 +259,53 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
}
e := &Exporter{
listener: listener,
- chans: make(map[string]*exportChan),
+ clientSet: &clientSet{
+ chans: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
}
go e.listen()
return e, nil
}
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn net.Conn) *expClient {
+ client := newClient(exp, conn)
+ exp.mu.Lock()
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
// Addr returns the Exporter's local network address.
func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
@@ -229,12 +341,28 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
if err != nil {
return err
}
- exp.chanLock.Lock()
- defer exp.chanLock.Unlock()
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
_, present := exp.chans[name]
if present {
return os.ErrorString("channel name already being exported:" + name)
}
- exp.chans[name] = &exportChan{ch, dir}
+ exp.chans[name] = &chanDir{ch, dir}
+ return nil
+}
+
+// Hangup disassociates the named channel from the Exporter and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (exp *Exporter) Hangup(name string) os.Error {
+ exp.mu.Lock()
+ chDir, ok := exp.chans[name]
+ if ok {
+ exp.chans[name] = nil, false
+ }
+ exp.mu.Unlock()
+ if !ok {
+ return os.ErrorString("netchan export: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
return nil
}
diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go
index 244a83c5b..baae367a0 100644
--- a/src/pkg/netchan/import.go
+++ b/src/pkg/netchan/import.go
@@ -14,11 +14,10 @@ import (
// Import
-// A channel and its associated information: a template value and direction,
-// plus a handy marshaling place for its data.
-type importChan struct {
- ch *reflect.ChanValue
- dir Dir
+// impLog is a logging convenience function. The first argument must be a string.
+func impLog(args ...interface{}) {
+ args[0] = "netchan import: " + args[0].(string)
+ log.Print(args...)
}
// An Importer allows a set of channels to be imported from a single
@@ -28,7 +27,8 @@ type Importer struct {
*encDec
conn net.Conn
chanLock sync.Mutex // protects access to channel map
- chans map[string]*importChan
+ chans map[string]*chanDir
+ errors chan os.Error
}
// NewImporter creates a new Importer object to import channels
@@ -43,7 +43,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
imp := new(Importer)
imp.encDec = newEncDec(conn)
imp.conn = conn
- imp.chans = make(map[string]*importChan)
+ imp.chans = make(map[string]*chanDir)
+ imp.errors = make(chan os.Error, 10)
go imp.run()
return imp, nil
}
@@ -67,61 +68,92 @@ func (imp *Importer) run() {
// Loop on responses; requests are sent by ImportNValues()
hdr := new(header)
hdrValue := reflect.NewValue(hdr)
+ ackHdr := new(header)
err := new(error)
errValue := reflect.NewValue(err)
for {
+ *hdr = header{}
if e := imp.decode(hdrValue); e != nil {
- log.Stderr("importer header:", e)
+ impLog("header:", e)
imp.shutdown()
return
}
- switch hdr.payloadType {
+ switch hdr.PayloadType {
case payData:
// done lower in loop
case payError:
if e := imp.decode(errValue); e != nil {
- log.Stderr("importer error:", e)
+ impLog("error:", e)
return
}
- if err.error != "" {
- log.Stderr("importer response error:", err.error)
- imp.shutdown()
- return
+ if err.Error != "" {
+ impLog("response error:", err.Error)
+ if sent := imp.errors <- os.ErrorString(err.Error); !sent {
+ imp.shutdown()
+ return
+ }
+ continue // errors are not acknowledged.
}
+ case payClosed:
+ ich := imp.getChan(hdr.Name)
+ if ich != nil {
+ ich.ch.Close()
+ }
+ continue // closes are not acknowledged.
default:
- log.Stderr("unexpected payload type:", hdr.payloadType)
+ impLog("unexpected payload type:", hdr.PayloadType)
return
}
- imp.chanLock.Lock()
- ich, ok := imp.chans[hdr.name]
- imp.chanLock.Unlock()
- if !ok {
- log.Stderr("unknown name in request:", hdr.name)
- return
+ ich := imp.getChan(hdr.Name)
+ if ich == nil {
+ continue
}
if ich.dir != Recv {
- log.Stderr("cannot happen: receive from non-Recv channel")
+ impLog("cannot happen: receive from non-Recv channel")
return
}
+ // Acknowledge receipt
+ ackHdr.Name = hdr.Name
+ ackHdr.SeqNum = hdr.SeqNum
+ imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item.
value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
if e := imp.decode(value); e != nil {
- log.Stderr("importer value decode:", e)
+ impLog("importer value decode:", e)
return
}
ich.ch.Send(value)
}
}
+func (imp *Importer) getChan(name string) *chanDir {
+ imp.chanLock.Lock()
+ ich := imp.chans[name]
+ imp.chanLock.Unlock()
+ if ich == nil {
+ impLog("unknown name in netchan request:", name)
+ return nil
+ }
+ return ich
+}
+
+// Errors returns a channel from which transmission and protocol errors
+// can be read. Clients of the importer are not required to read the error
+// channel for correct execution. However, if too many errors occur
+// without being read from the error channel, the importer will shut down.
+func (imp *Importer) Errors() chan os.Error {
+ return imp.errors
+}
+
// Import imports a channel of the given type and specified direction.
-// It is equivalent to ImportNValues with a count of 0, meaning unbounded.
+// It is equivalent to ImportNValues with a count of -1, meaning unbounded.
func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
- return imp.ImportNValues(name, chT, dir, 0)
+ return imp.ImportNValues(name, chT, dir, -1)
}
// ImportNValues imports a channel of the given type and specified direction
// and then receives or transmits up to n values on that channel. A value of
-// n==0 implies an unbounded number of values. The channel to be bound to
+// n==-1 implies an unbounded number of values. The channel to be bound to
// the remote site's channel is provided in the call and may be of arbitrary
// channel type.
// Despite the literal signature, the effective signature is
@@ -130,7 +162,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
// imp, err := NewImporter("tcp", "netchanserver.mydomain.com:1234")
// if err != nil { log.Exit(err) }
// ch := make(chan myType)
-// err := imp.ImportNValues("name", ch, Recv, 1)
+// err = imp.ImportNValues("name", ch, Recv, 1)
// if err != nil { log.Exit(err) }
// fmt.Printf("%+v\n", <-ch)
func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) os.Error {
@@ -144,24 +176,26 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
if present {
return os.ErrorString("channel name already being imported:" + name)
}
- imp.chans[name] = &importChan{ch, dir}
+ imp.chans[name] = &chanDir{ch, dir}
// Tell the other side about this channel.
- hdr := new(header)
- hdr.name = name
- hdr.payloadType = payRequest
- req := new(request)
- req.dir = dir
- req.count = n
- if err := imp.encode(hdr, payRequest, req); err != nil {
- log.Stderr("importer request encode:", err)
+ hdr := &header{Name: name}
+ req := &request{Count: int64(n), Dir: dir}
+ if err = imp.encode(hdr, payRequest, req); err != nil {
+ impLog("request encode:", err)
return err
}
if dir == Send {
go func() {
- for i := 0; n == 0 || i < n; i++ {
+ for i := 0; n == -1 || i < n; i++ {
val := ch.Recv()
- if err := imp.encode(hdr, payData, val.Interface()); err != nil {
- log.Stderr("error encoding client response:", err)
+ if ch.Closed() {
+ if err = imp.encode(hdr, payClosed, nil); err != nil {
+ impLog("error encoding client closed message:", err)
+ }
+ return
+ }
+ if err = imp.encode(hdr, payData, val.Interface()); err != nil {
+ impLog("error encoding client send:", err)
return
}
}
@@ -169,3 +203,19 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
}
return nil
}
+
+// Hangup disassociates the named channel from the Importer and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (imp *Importer) Hangup(name string) os.Error {
+ imp.chanLock.Lock()
+ chDir, ok := imp.chans[name]
+ if ok {
+ imp.chans[name] = nil, false
+ }
+ imp.chanLock.Unlock()
+ if !ok {
+ return os.ErrorString("netchan import: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
+ return nil
+}
diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go
index 6b5c67c3c..766c4c474 100644
--- a/src/pkg/netchan/netchan_test.go
+++ b/src/pkg/netchan/netchan_test.go
@@ -4,38 +4,67 @@
package netchan
-import "testing"
+import (
+ "strings"
+ "testing"
+ "time"
+)
const count = 10 // number of items in most tests
const closeCount = 5 // number of items when sender closes early
+const base = 23
+
func exportSend(exp *Exporter, n int, t *testing.T) {
ch := make(chan int)
err := exp.Export("exportedSend", ch, Send)
if err != nil {
t.Fatal("exportSend:", err)
}
- for i := 0; i < n; i++ {
- ch <- 23+i
- }
- close(ch)
+ go func() {
+ for i := 0; i < n; i++ {
+ ch <- base+i
+ }
+ close(ch)
+ }()
}
-func exportReceive(exp *Exporter, t *testing.T) {
+func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) {
ch := make(chan int)
err := exp.Export("exportedRecv", ch, Recv)
+ expDone <- true
if err != nil {
t.Fatal("exportReceive:", err)
}
for i := 0; i < count; i++ {
v := <-ch
- if v != 45+i {
- t.Errorf("export Receive: bad value: expected 4%d; got %d", 45+i, v)
+ if closed(ch) {
+ if i != closeCount {
+ t.Errorf("exportReceive expected close at %d; got one at %d", closeCount, i)
+ }
+ break
+ }
+ if v != base+i {
+ t.Errorf("export Receive: bad value: expected %d+%d=%d; got %d", base, i, base+i, v)
}
}
}
-func importReceive(imp *Importer, t *testing.T) {
+func importSend(imp *Importer, n int, t *testing.T) {
+ ch := make(chan int)
+ err := imp.ImportNValues("exportedRecv", ch, Send, count)
+ if err != nil {
+ t.Fatal("importSend:", err)
+ }
+ go func() {
+ for i := 0; i < n; i++ {
+ ch <- base+i
+ }
+ close(ch)
+ }()
+}
+
+func importReceive(imp *Importer, t *testing.T, done chan bool) {
ch := make(chan int)
err := imp.ImportNValues("exportedSend", ch, Recv, count)
if err != nil {
@@ -45,28 +74,33 @@ func importReceive(imp *Importer, t *testing.T) {
v := <-ch
if closed(ch) {
if i != closeCount {
- t.Errorf("expected close at %d; got one at %d\n", closeCount, i)
+ t.Errorf("importReceive expected close at %d; got one at %d", closeCount, i)
}
break
}
if v != 23+i {
- t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v)
+ t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v)
}
}
+ if done != nil {
+ done <- true
+ }
}
-func importSend(imp *Importer, t *testing.T) {
- ch := make(chan int)
- err := imp.ImportNValues("exportedRecv", ch, Send, count)
+func TestExportSendImportReceive(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
- t.Fatal("importSend:", err)
+ t.Fatal("new exporter:", err)
}
- for i := 0; i < count; i++ {
- ch <- 45+i
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
}
+ exportSend(exp, count, t)
+ importReceive(imp, t, nil)
}
-func TestExportSendImportReceive(t *testing.T) {
+func TestExportReceiveImportSend(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
@@ -75,11 +109,18 @@ func TestExportSendImportReceive(t *testing.T) {
if err != nil {
t.Fatal("new importer:", err)
}
- go exportSend(exp, count, t)
- importReceive(imp, t)
+ expDone := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ exportReceive(exp, t, expDone)
+ done <- true
+ }()
+ <-expDone
+ importSend(imp, count, t)
+ <-done
}
-func TestExportReceiveImportSend(t *testing.T) {
+func TestClosingExportSendImportReceive(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
@@ -88,11 +129,92 @@ func TestExportReceiveImportSend(t *testing.T) {
if err != nil {
t.Fatal("new importer:", err)
}
- go importSend(imp, t)
- exportReceive(exp, t)
+ exportSend(exp, closeCount, t)
+ importReceive(imp, t, nil)
}
-func TestClosingExportSendImportReceive(t *testing.T) {
+func TestClosingImportSendExportReceive(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ expDone := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ exportReceive(exp, t, expDone)
+ done <- true
+ }()
+ <-expDone
+ importSend(imp, closeCount, t)
+ <-done
+}
+
+func TestErrorForIllegalChannel(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ // Now export a channel.
+ ch := make(chan int, 1)
+ err = exp.Export("aChannel", ch, Send)
+ if err != nil {
+ t.Fatal("export:", err)
+ }
+ ch <- 1234
+ close(ch)
+ // Now try to import a different channel.
+ ch = make(chan int)
+ err = imp.Import("notAChannel", ch, Recv)
+ if err != nil {
+ t.Fatal("import:", err)
+ }
+ // Expect an error now. Start a timeout.
+ timeout := make(chan bool, 1) // buffered so closure will not hang around.
+ go func() {
+ time.Sleep(10e9) // very long, to give even really slow machines a chance.
+ timeout <- true
+ }()
+ select {
+ case err = <-imp.Errors():
+ if strings.Index(err.String(), "no such channel") < 0 {
+ t.Error("wrong error for nonexistent channel:", err)
+ }
+ case <-timeout:
+ t.Error("import of nonexistent channel did not receive an error")
+ }
+}
+
+// Not a great test but it does at least invoke Drain.
+func TestExportDrain(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ done := make(chan bool)
+ go func() {
+ exportSend(exp, closeCount, t)
+ done <- true
+ }()
+ <-done
+ go importReceive(imp, t, done)
+ exp.Drain(0)
+ <-done
+}
+
+// Not a great test but it does at least invoke Sync.
+func TestExportSync(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
@@ -101,6 +223,161 @@ func TestClosingExportSendImportReceive(t *testing.T) {
if err != nil {
t.Fatal("new importer:", err)
}
- go exportSend(exp, closeCount, t)
- importReceive(imp, t)
+ done := make(chan bool)
+ exportSend(exp, closeCount, t)
+ go importReceive(imp, t, done)
+ exp.Sync(0)
+ <-done
+}
+
+// Test hanging up the send side of an export.
+// TODO: test hanging up the receive side of an export.
+func TestExportHangup(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ ech := make(chan int)
+ err = exp.Export("exportedSend", ech, Send)
+ if err != nil {
+ t.Fatal("export:", err)
+ }
+ // Prepare to receive two values. We'll actually deliver only one.
+ ich := make(chan int)
+ err = imp.ImportNValues("exportedSend", ich, Recv, 2)
+ if err != nil {
+ t.Fatal("import exportedSend:", err)
+ }
+ // Send one value, receive it.
+ const Value = 1234
+ ech <- Value
+ v := <-ich
+ if v != Value {
+ t.Fatal("expected", Value, "got", v)
+ }
+ // Now hang up the channel. Importer should see it close.
+ exp.Hangup("exportedSend")
+ v = <-ich
+ if !closed(ich) {
+ t.Fatal("expected channel to be closed; got value", v)
+ }
+}
+
+// Test hanging up the send side of an import.
+// TODO: test hanging up the receive side of an import.
+func TestImportHangup(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ ech := make(chan int)
+ err = exp.Export("exportedRecv", ech, Recv)
+ if err != nil {
+ t.Fatal("export:", err)
+ }
+ // Prepare to Send two values. We'll actually deliver only one.
+ ich := make(chan int)
+ err = imp.ImportNValues("exportedRecv", ich, Send, 2)
+ if err != nil {
+ t.Fatal("import exportedRecv:", err)
+ }
+ // Send one value, receive it.
+ const Value = 1234
+ ich <- Value
+ v := <-ech
+ if v != Value {
+ t.Fatal("expected", Value, "got", v)
+ }
+ // Now hang up the channel. Exporter should see it close.
+ imp.Hangup("exportedRecv")
+ v = <-ech
+ if !closed(ech) {
+ t.Fatal("expected channel to be closed; got value", v)
+ }
+}
+
+// This test cross-connects a pair of exporter/importer pairs.
+type value struct {
+ i int
+ source string
+}
+
+func TestCrossConnect(t *testing.T) {
+ e1, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ i1, err := NewImporter("tcp", e1.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+
+ e2, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ i2, err := NewImporter("tcp", e2.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+
+ go crossExport(e1, e2, t)
+ crossImport(i1, i2, t)
+}
+
+// Export side of cross-traffic.
+func crossExport(e1, e2 *Exporter, t *testing.T) {
+ s := make(chan value)
+ err := e1.Export("exportedSend", s, Send)
+ if err != nil {
+ t.Fatal("exportSend:", err)
+ }
+
+ r := make(chan value)
+ err = e2.Export("exportedReceive", r, Recv)
+ if err != nil {
+ t.Fatal("exportReceive:", err)
+ }
+
+ crossLoop("export", s, r, t)
+}
+
+// Import side of cross-traffic.
+func crossImport(i1, i2 *Importer, t *testing.T) {
+ s := make(chan value)
+ err := i2.Import("exportedReceive", s, Send)
+ if err != nil {
+ t.Fatal("import of exportedReceive:", err)
+ }
+
+ r := make(chan value)
+ err = i1.Import("exportedSend", r, Recv)
+ if err != nil {
+ t.Fatal("import of exported Send:", err)
+ }
+
+ crossLoop("import", s, r, t)
+}
+
+// Cross-traffic: send and receive 'count' numbers.
+func crossLoop(name string, s, r chan value, t *testing.T) {
+ for si, ri := 0, 0; si < count && ri < count; {
+ select {
+ case s <- value{si, name}:
+ si++
+ case v := <-r:
+ if v.i != ri {
+ t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v)
+ }
+ ri++
+ }
+ }
}