summaryrefslogtreecommitdiff
path: root/src/pkg/netchan
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/netchan')
-rw-r--r--src/pkg/netchan/import.go45
-rw-r--r--src/pkg/netchan/netchan_test.go10
2 files changed, 53 insertions, 2 deletions
diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go
index 0a700ca2b..7d96228c4 100644
--- a/src/pkg/netchan/import.go
+++ b/src/pkg/netchan/import.go
@@ -11,6 +11,7 @@ import (
"os"
"reflect"
"sync"
+ "time"
)
// Import
@@ -31,6 +32,9 @@ type Importer struct {
chans map[int]*netChan
errors chan os.Error
maxId int
+ mu sync.Mutex // protects remaining fields
+ unacked int64 // number of unacknowledged sends.
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
}
// NewImporter creates a new Importer object to import a set of channels
@@ -42,6 +46,7 @@ func NewImporter(conn io.ReadWriter) *Importer {
imp.chans = make(map[int]*netChan)
imp.names = make(map[string]*netChan)
imp.errors = make(chan os.Error, 10)
+ imp.unacked = 0
go imp.run()
return imp
}
@@ -80,8 +85,10 @@ func (imp *Importer) run() {
for {
*hdr = header{}
if e := imp.decode(hdrValue); e != nil {
- impLog("header:", e)
- imp.shutdown()
+ if e != os.EOF {
+ impLog("header:", e)
+ imp.shutdown()
+ }
return
}
switch hdr.PayloadType {
@@ -114,6 +121,9 @@ func (imp *Importer) run() {
nch := imp.getChan(hdr.Id, true)
if nch != nil {
nch.acked()
+ imp.mu.Lock()
+ imp.unacked--
+ imp.mu.Unlock()
}
continue
default:
@@ -220,10 +230,17 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
}
return
}
+ // We hold the lock during transmission to guarantee messages are
+ // sent in order.
+ imp.mu.Lock()
+ imp.unacked++
+ imp.seqLock.Lock()
+ imp.mu.Unlock()
if err = imp.encode(hdr, payData, val.Interface()); err != nil {
impLog("error encoding client send:", err)
return
}
+ imp.seqLock.Unlock()
}
}()
}
@@ -244,3 +261,27 @@ func (imp *Importer) Hangup(name string) os.Error {
nc.close()
return nil
}
+
+func (imp *Importer) unackedCount() int64 {
+ imp.mu.Lock()
+ n := imp.unacked
+ imp.mu.Unlock()
+ return n
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any server and possibly including those sent while
+// Drain was executing, have been received by the exporter. In short, it
+// waits until all the importer's messages have been received.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (imp *Importer) Drain(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ for imp.unackedCount() > 0 {
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6)
+ }
+ return nil
+}
diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go
index fd4d8f780..8c0f9a6e4 100644
--- a/src/pkg/netchan/netchan_test.go
+++ b/src/pkg/netchan/netchan_test.go
@@ -178,6 +178,16 @@ func TestExportDrain(t *testing.T) {
<-done
}
+// Not a great test but it does at least invoke Drain.
+func TestImportDrain(t *testing.T) {
+ exp, imp := pair(t)
+ expDone := make(chan bool)
+ go exportReceive(exp, t, expDone)
+ <-expDone
+ importSend(imp, closeCount, t, nil)
+ imp.Drain(0)
+}
+
// Not a great test but it does at least invoke Sync.
func TestExportSync(t *testing.T) {
exp, imp := pair(t)