summaryrefslogtreecommitdiff
path: root/src/pkg/database
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/database')
-rw-r--r--src/pkg/database/sql/convert.go62
-rw-r--r--src/pkg/database/sql/convert_test.go33
-rw-r--r--src/pkg/database/sql/driver/driver.go4
-rw-r--r--src/pkg/database/sql/fakedb_test.go61
-rw-r--r--src/pkg/database/sql/sql.go624
-rw-r--r--src/pkg/database/sql/sql_test.go416
6 files changed, 983 insertions, 217 deletions
diff --git a/src/pkg/database/sql/convert.go b/src/pkg/database/sql/convert.go
index 853a7826c..c04adde1f 100644
--- a/src/pkg/database/sql/convert.go
+++ b/src/pkg/database/sql/convert.go
@@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
-// The statement si may be nil, if no statement is available.
-func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
+// The statement ds may be nil, if no statement is available.
+func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
dargs := make([]driver.Value, len(args))
+ var si driver.Stmt
+ if ds != nil {
+ si = ds.si
+ }
cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter.
@@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// column before going across the network to get the
// same error.
var err error
+ ds.Lock()
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
+ ds.Unlock()
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
}
@@ -106,25 +112,41 @@ func convertAssign(dest, src interface{}) error {
if d == nil {
return errNilPtr
}
- bcopy := make([]byte, len(s))
- copy(bcopy, s)
- *d = bcopy
+ *d = cloneBytes(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
+ *d = cloneBytes(s)
+ return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
*d = s
return nil
}
case nil:
switch d := dest.(type) {
+ case *interface{}:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = nil
+ return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = nil
+ return nil
}
}
@@ -141,6 +163,26 @@ func convertAssign(dest, src interface{}) error {
*d = fmt.Sprintf("%v", src)
return nil
}
+ case *[]byte:
+ sv = reflect.ValueOf(src)
+ switch sv.Kind() {
+ case reflect.Bool,
+ reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+ reflect.Float32, reflect.Float64:
+ *d = []byte(fmt.Sprintf("%v", src))
+ return nil
+ }
+ case *RawBytes:
+ sv = reflect.ValueOf(src)
+ switch sv.Kind() {
+ case reflect.Bool,
+ reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+ reflect.Float32, reflect.Float64:
+ *d = RawBytes(fmt.Sprintf("%v", src))
+ return nil
+ }
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
@@ -212,6 +254,16 @@ func convertAssign(dest, src interface{}) error {
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
}
+func cloneBytes(b []byte) []byte {
+ if b == nil {
+ return nil
+ } else {
+ c := make([]byte, len(b))
+ copy(c, b)
+ return c
+ }
+}
+
func asString(src interface{}) string {
switch v := src.(type) {
case string:
diff --git a/src/pkg/database/sql/convert_test.go b/src/pkg/database/sql/convert_test.go
index 9c362d732..950e24fc3 100644
--- a/src/pkg/database/sql/convert_test.go
+++ b/src/pkg/database/sql/convert_test.go
@@ -22,6 +22,8 @@ type conversionTest struct {
wantint int64
wantuint uint64
wantstr string
+ wantbytes []byte
+ wantraw RawBytes
wantf32 float32
wantf64 float64
wanttime time.Time
@@ -35,6 +37,8 @@ type conversionTest struct {
// Target variables for scanning into.
var (
scanstr string
+ scanbytes []byte
+ scanraw RawBytes
scanint int
scanint8 int8
scanint16 int16
@@ -56,6 +60,7 @@ var conversionTests = []conversionTest{
{s: someTime, d: &scantime, wanttime: someTime},
// To strings
+ {s: "string", d: &scanstr, wantstr: "string"},
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
{s: 123, d: &scanstr, wantstr: "123"},
{s: int8(123), d: &scanstr, wantstr: "123"},
@@ -66,6 +71,31 @@ var conversionTests = []conversionTest{
{s: uint64(123), d: &scanstr, wantstr: "123"},
{s: 1.5, d: &scanstr, wantstr: "1.5"},
+ // To []byte
+ {s: nil, d: &scanbytes, wantbytes: nil},
+ {s: "string", d: &scanbytes, wantbytes: []byte("string")},
+ {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
+ {s: 123, d: &scanbytes, wantbytes: []byte("123")},
+ {s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
+
+ // To RawBytes
+ {s: nil, d: &scanraw, wantraw: nil},
+ {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
+ {s: 123, d: &scanraw, wantraw: RawBytes("123")},
+ {s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
+
// Strings to integers
{s: "255", d: &scanuint8, wantuint: 255},
{s: "256", d: &scanuint8, wanterr: `converting string "256" to a uint8: strconv.ParseUint: parsing "256": value out of range`},
@@ -113,6 +143,7 @@ var conversionTests = []conversionTest{
{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
{s: true, d: &scaniface, wantiface: true},
{s: nil, d: &scaniface},
+ {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
}
func intPtrValue(intptr interface{}) interface{} {
@@ -191,7 +222,7 @@ func TestConversions(t *testing.T) {
}
if srcBytes, ok := ct.s.([]byte); ok {
dstBytes := (*ifptr).([]byte)
- if &dstBytes[0] == &srcBytes[0] {
+ if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
errf("copy into interface{} didn't copy []byte data")
}
}
diff --git a/src/pkg/database/sql/driver/driver.go b/src/pkg/database/sql/driver/driver.go
index 2434e419b..d7ca94f78 100644
--- a/src/pkg/database/sql/driver/driver.go
+++ b/src/pkg/database/sql/driver/driver.go
@@ -10,8 +10,8 @@ package driver
import "errors"
-// A driver Value is a value that drivers must be able to handle.
-// A Value is either nil or an instance of one of these types:
+// Value is a value that drivers must be able to handle.
+// It is either nil or an instance of one of these types:
//
// int64
// float64
diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go
index 55597f7de..d900e2ceb 100644
--- a/src/pkg/database/sql/fakedb_test.go
+++ b/src/pkg/database/sql/fakedb_test.go
@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
"sync"
+ "testing"
"time"
)
@@ -34,9 +35,10 @@ var _ = log.Printf
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
- mu sync.Mutex
- openCount int
- dbs map[string]*fakeDB
+ mu sync.Mutex // guards 3 following fields
+ openCount int // conn opens
+ closeCount int // conn closes
+ dbs map[string]*fakeDB
}
type fakeDB struct {
@@ -229,7 +231,43 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
return c.currTx, nil
}
-func (c *fakeConn) Close() error {
+var hookPostCloseConn struct {
+ sync.Mutex
+ fn func(*fakeConn, error)
+}
+
+func setHookpostCloseConn(fn func(*fakeConn, error)) {
+ hookPostCloseConn.Lock()
+ defer hookPostCloseConn.Unlock()
+ hookPostCloseConn.fn = fn
+}
+
+var testStrictClose *testing.T
+
+// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
+// fails to close. If nil, the check is disabled.
+func setStrictFakeConnClose(t *testing.T) {
+ testStrictClose = t
+}
+
+func (c *fakeConn) Close() (err error) {
+ drv := fdriver.(*fakeDriver)
+ defer func() {
+ if err != nil && testStrictClose != nil {
+ testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
+ }
+ hookPostCloseConn.Lock()
+ fn := hookPostCloseConn.fn
+ hookPostCloseConn.Unlock()
+ if fn != nil {
+ fn(c, err)
+ }
+ if err == nil {
+ drv.mu.Lock()
+ drv.closeCount++
+ drv.mu.Unlock()
+ }
+ }()
if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction")
}
@@ -424,6 +462,12 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
}
func (s *fakeStmt) Close() error {
+ if s.c == nil {
+ panic("nil conn in fakeStmt.Close")
+ }
+ if s.c.db == nil {
+ panic("in fakeStmt.Close, conn's db is nil (already closed)")
+ }
if !s.closed {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
@@ -515,6 +559,15 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
+
+ if s.table == "magicquery" {
+ if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
+ if args[0] == "sleep" {
+ time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+ }
+ }
+ }
+
t.mu.Lock()
defer t.mu.Unlock()
diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go
index 4faaa11b1..a80782bfe 100644
--- a/src/pkg/database/sql/sql.go
+++ b/src/pkg/database/sql/sql.go
@@ -4,6 +4,9 @@
// Package sql provides a generic interface around SQL (or SQL-like)
// databases.
+//
+// The sql package must be used in conjunction with a database driver.
+// See http://golang.org/s/sqldrivers for a list of drivers.
package sql
import (
@@ -177,32 +180,139 @@ var ErrNoRows = errors.New("sql: no rows in result set")
// DB is a database handle. It's safe for concurrent use by multiple
// goroutines.
//
-// If the underlying database driver has the concept of a connection
-// and per-connection session state, the sql package manages creating
-// and freeing connections automatically, including maintaining a free
-// pool of idle connections. If observing session state is required,
-// either do not share a *DB between multiple concurrent goroutines or
-// create and observe all state only within a transaction. Once
-// DB.Open is called, the returned Tx is bound to a single isolated
-// connection. Once Tx.Commit or Tx.Rollback is called, that
-// connection is returned to DB's idle connection pool.
+// The sql package creates and frees connections automatically; it
+// also maintains a free pool of idle connections. If the database has
+// a concept of per-connection state, such state can only be reliably
+// observed within a transaction. Once DB.Begin is called, the
+// returned Tx is bound to a single connection. Once Commit or
+// Rollback is called on the transaction, that transaction's
+// connection is returned to DB's idle connection pool. The pool size
+// can be controlled with SetMaxIdleConns.
type DB struct {
driver driver.Driver
dsn string
- mu sync.Mutex // protects following fields
- outConn map[driver.Conn]bool // whether the conn is in use
- freeConn []driver.Conn
- closed bool
- dep map[finalCloser]depSet
- onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned
- lastPut map[driver.Conn]string // stacktrace of last conn's put; debug only
+ mu sync.Mutex // protects following fields
+ freeConn []*driverConn
+ closed bool
+ dep map[finalCloser]depSet
+ lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
+ maxIdle int // zero means defaultMaxIdleConns; negative means 0
+}
+
+// driverConn wraps a driver.Conn with a mutex, to
+// be held during all calls into the Conn. (including any calls onto
+// interfaces returned via that Conn, such as calls on Tx, Stmt,
+// Result, Rows)
+type driverConn struct {
+ db *DB
+
+ sync.Mutex // guards following
+ ci driver.Conn
+ closed bool
+ finalClosed bool // ci.Close has been called
+ openStmt map[driver.Stmt]bool
+
+ // guarded by db.mu
+ inUse bool
+ onPut []func() // code (with db.mu held) run when conn is next returned
+ dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree
+}
+
+func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+ dc.Lock()
+ defer dc.Unlock()
+ delete(dc.openStmt, si)
+}
+
+func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
+ si, err := dc.ci.Prepare(query)
+ if err == nil {
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // TODO(bradfitz): let drivers opt out of caring about
+ // stmt closes if the conn is about to close anyway? For now
+ // do the safe thing, in case stmts need to be closed.
+ //
+ // TODO(bradfitz): after Go 1.1, closing driver.Stmts
+ // should be moved to driverStmt, using unique
+ // *driverStmts everywhere (including from
+ // *Stmt.connStmt, instead of returning a
+ // driver.Stmt), using driverStmt as a pointer
+ // everywhere, and making it a finalCloser.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[driver.Stmt]bool)
+ }
+ dc.openStmt[si] = true
+ }
+ return si, err
+}
+
+// the dc.db's Mutex is held.
+func (dc *driverConn) closeDBLocked() error {
+ dc.Lock()
+ if dc.closed {
+ dc.Unlock()
+ return errors.New("sql: duplicate driverConn close")
+ }
+ dc.closed = true
+ dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+ return dc.db.removeDepLocked(dc, dc)()
+}
+
+func (dc *driverConn) Close() error {
+ dc.Lock()
+ if dc.closed {
+ dc.Unlock()
+ return errors.New("sql: duplicate driverConn close")
+ }
+ dc.closed = true
+ dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+
+ // And now updates that require holding dc.mu.Lock.
+ dc.db.mu.Lock()
+ dc.dbmuClosed = true
+ fn := dc.db.removeDepLocked(dc, dc)
+ dc.db.mu.Unlock()
+ return fn()
+}
+
+func (dc *driverConn) finalClose() error {
+ dc.Lock()
+
+ for si := range dc.openStmt {
+ si.Close()
+ }
+ dc.openStmt = nil
+
+ err := dc.ci.Close()
+ dc.ci = nil
+ dc.finalClosed = true
+
+ dc.Unlock()
+ return err
+}
+
+// driverStmt associates a driver.Stmt with the
+// *driverConn from which it came, so the driverConn's lock can be
+// held during calls.
+type driverStmt struct {
+ sync.Locker // the *driverConn
+ si driver.Stmt
+}
+
+func (ds *driverStmt) Close() error {
+ ds.Lock()
+ defer ds.Unlock()
+ return ds.si.Close()
}
// depSet is a finalCloser's outstanding dependencies
type depSet map[interface{}]bool // set of true bools
-// The finalCloser interface is used by (*DB).addDep and (*DB).get
+// The finalCloser interface is used by (*DB).addDep and related
+// dependency reference counting.
type finalCloser interface {
// finalClose is called when the reference count of an object
// goes to zero. (*DB).mu is not held while calling it.
@@ -215,6 +325,10 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
//println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
db.mu.Lock()
defer db.mu.Unlock()
+ db.addDepLocked(x, dep)
+}
+
+func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
if db.dep == nil {
db.dep = make(map[finalCloser]depSet)
}
@@ -231,10 +345,16 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
// If x no longer has any dependencies, its finalClose method will be
// called and its error value will be returned.
func (db *DB) removeDep(x finalCloser, dep interface{}) error {
+ db.mu.Lock()
+ fn := db.removeDepLocked(x, dep)
+ db.mu.Unlock()
+ return fn()
+}
+
+func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
//println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
done := false
- db.mu.Lock()
xdep := db.dep[x]
if xdep != nil {
delete(xdep, dep)
@@ -243,13 +363,14 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error {
done = true
}
}
- db.mu.Unlock()
if !done {
- return nil
+ return func() error { return nil }
+ }
+ return func() error {
+ //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
+ return x.finalClose()
}
- //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
- return x.finalClose()
}
// Open opens a database specified by its database driver name and a
@@ -257,31 +378,47 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error {
// database name and connection information.
//
// Most users will open a database via a driver-specific connection
-// helper function that returns a *DB.
+// helper function that returns a *DB. No database drivers are included
+// in the Go standard library. See http://golang.org/s/sqldrivers for
+// a list of third-party drivers.
+//
+// Open may just validate its arguments without creating a connection
+// to the database. To verify that the data source name is valid, call
+// Ping.
func Open(driverName, dataSourceName string) (*DB, error) {
driveri, ok := drivers[driverName]
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
- // TODO: optionally proactively connect to a Conn to check
- // the dataSourceName: golang.org/issue/4804
db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- outConn: make(map[driver.Conn]bool),
- lastPut: make(map[driver.Conn]string),
- onConnPut: make(map[driver.Conn][]func()),
+ driver: driveri,
+ dsn: dataSourceName,
+ lastPut: make(map[*driverConn]string),
}
return db, nil
}
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+ // TODO(bradfitz): give drivers an optional hook to implement
+ // this in a more efficient or more reliable way, if they
+ // have one.
+ dc, err := db.conn()
+ if err != nil {
+ return err
+ }
+ db.putConn(dc, nil)
+ return nil
+}
+
// Close closes the database, releasing any open resources.
func (db *DB) Close() error {
db.mu.Lock()
defer db.mu.Unlock()
var err error
- for _, c := range db.freeConn {
- err1 := c.Close()
+ for _, dc := range db.freeConn {
+ err1 := dc.closeDBLocked()
if err1 != nil {
err = err1
}
@@ -291,15 +428,45 @@ func (db *DB) Close() error {
return err
}
-func (db *DB) maxIdleConns() int {
- const defaultMaxIdleConns = 2
- // TODO(bradfitz): ask driver, if supported, for its default preference
- // TODO(bradfitz): let users override?
- return defaultMaxIdleConns
+const defaultMaxIdleConns = 2
+
+func (db *DB) maxIdleConnsLocked() int {
+ n := db.maxIdle
+ switch {
+ case n == 0:
+ // TODO(bradfitz): ask driver, if supported, for its default preference
+ return defaultMaxIdleConns
+ case n < 0:
+ return 0
+ default:
+ return n
+ }
}
-// conn returns a newly-opened or cached driver.Conn
-func (db *DB) conn() (driver.Conn, error) {
+// SetMaxIdleConns sets the maximum number of connections in the idle
+// connection pool.
+//
+// If n <= 0, no idle connections are retained.
+func (db *DB) SetMaxIdleConns(n int) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if n > 0 {
+ db.maxIdle = n
+ } else {
+ // No idle connections.
+ db.maxIdle = -1
+ }
+ for len(db.freeConn) > 0 && len(db.freeConn) > n {
+ nfree := len(db.freeConn)
+ dc := db.freeConn[nfree-1]
+ db.freeConn[nfree-1] = nil
+ db.freeConn = db.freeConn[:nfree-1]
+ go dc.Close()
+ }
+}
+
+// conn returns a newly-opened or cached *driverConn
+func (db *DB) conn() (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
@@ -308,30 +475,47 @@ func (db *DB) conn() (driver.Conn, error) {
if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1]
- db.outConn[conn] = true
+ conn.inUse = true
db.mu.Unlock()
return conn, nil
}
db.mu.Unlock()
- conn, err := db.driver.Open(db.dsn)
- if err == nil {
- db.mu.Lock()
- db.outConn[conn] = true
- db.mu.Unlock()
+
+ ci, err := db.driver.Open(db.dsn)
+ if err != nil {
+ return nil, err
+ }
+ dc := &driverConn{
+ db: db,
+ ci: ci,
}
- return conn, err
+ db.mu.Lock()
+ db.addDepLocked(dc, dc)
+ dc.inUse = true
+ db.mu.Unlock()
+ return dc, nil
}
-// connIfFree returns (wanted, true) if wanted is still a valid conn and
+var (
+ errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed")
+ errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy")
+)
+
+// connIfFree returns (wanted, nil) if wanted is still a valid conn and
// isn't in use.
//
-// If wanted is valid but in use, connIfFree returns (wanted, false).
-// If wanted is invalid, connIfFre returns (nil, false).
-func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
+// The error is errConnClosed if the connection if the requested connection
+// is invalid because it's been closed.
+//
+// The error is errConnBusy if the connection is in use.
+func (db *DB) connIfFree(wanted *driverConn) (*driverConn, error) {
db.mu.Lock()
defer db.mu.Unlock()
- if db.outConn[wanted] {
- return conn, false
+ if wanted.inUse {
+ return nil, errConnBusy
+ }
+ if wanted.dbmuClosed {
+ return nil, errConnClosed
}
for i, conn := range db.freeConn {
if conn != wanted {
@@ -339,27 +523,36 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
}
db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
db.freeConn = db.freeConn[:len(db.freeConn)-1]
- db.outConn[wanted] = true
- return wanted, true
+ wanted.inUse = true
+ return wanted, nil
}
- return nil, false
+ // TODO(bradfitz): shouldn't get here. After Go 1.1, change this to:
+ // panic("connIfFree call requested a non-closed, non-busy, non-free conn")
+ // Which passes all the tests, but I'm too paranoid to include this
+ // late in Go 1.1.
+ // Instead, treat it like a busy connection:
+ return nil, errConnBusy
}
// putConnHook is a hook for testing.
-var putConnHook func(*DB, driver.Conn)
+var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that si is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
-func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
db.mu.Lock()
defer db.mu.Unlock()
- if db.outConn[c] {
- db.onConnPut[c] = append(db.onConnPut[c], func() {
+ if c.inUse {
+ c.onPut = append(c.onPut, func() {
si.Close()
})
} else {
- si.Close()
+ c.Lock()
+ defer c.Unlock()
+ if !c.finalClosed {
+ si.Close()
+ }
}
}
@@ -369,25 +562,23 @@ const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
-func (db *DB) putConn(c driver.Conn, err error) {
+func (db *DB) putConn(dc *driverConn, err error) {
db.mu.Lock()
- if !db.outConn[c] {
+ if !dc.inUse {
if debugGetPut {
- fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c])
+ fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
}
panic("sql: connection returned that was never out")
}
if debugGetPut {
- db.lastPut[c] = stack()
+ db.lastPut[dc] = stack()
}
- delete(db.outConn, c)
+ dc.inUse = false
- if fns, ok := db.onConnPut[c]; ok {
- for _, fn := range fns {
- fn()
- }
- delete(db.onConnPut, c)
+ for _, fn := range dc.onPut {
+ fn()
}
+ dc.onPut = nil
if err == driver.ErrBadConn {
// Don't reuse bad connections.
@@ -395,17 +586,16 @@ func (db *DB) putConn(c driver.Conn, err error) {
return
}
if putConnHook != nil {
- putConnHook(db, c)
+ putConnHook(db, dc)
}
- if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
- db.freeConn = append(db.freeConn, c)
+ if n := len(db.freeConn); !db.closed && n < db.maxIdleConnsLocked() {
+ db.freeConn = append(db.freeConn, dc)
db.mu.Unlock()
return
}
- // TODO: check to see if we need this Conn for any prepared
- // statements which are still active?
db.mu.Unlock()
- c.Close()
+
+ dc.Close()
}
// Prepare creates a prepared statement for later queries or executions.
@@ -430,21 +620,24 @@ func (db *DB) prepare(query string) (*Stmt, error) {
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
- ci, err := db.conn()
+ dc, err := db.conn()
if err != nil {
return nil, err
}
- si, err := ci.Prepare(query)
+ dc.Lock()
+ si, err := dc.prepareLocked(query)
+ dc.Unlock()
if err != nil {
- db.putConn(ci, err)
+ db.putConn(dc, err)
return nil, err
}
stmt := &Stmt{
db: db,
query: query,
- css: []connStmt{{ci, si}},
+ css: []connStmt{{dc, si}},
}
db.addDep(stmt, stmt)
+ db.putConn(dc, nil)
return stmt, nil
}
@@ -463,35 +656,38 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
}
func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
- ci, err := db.conn()
+ dc, err := db.conn()
if err != nil {
return nil, err
}
defer func() {
- db.putConn(ci, err)
+ db.putConn(dc, err)
}()
- if execer, ok := ci.(driver.Execer); ok {
+ if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}
+ dc.Lock()
resi, err := execer.Exec(query, dargs)
+ dc.Unlock()
if err != driver.ErrSkip {
if err != nil {
return nil, err
}
- return result{resi}, nil
+ return driverResult{dc, resi}, nil
}
}
- sti, err := ci.Prepare(query)
+ dc.Lock()
+ si, err := dc.ci.Prepare(query)
+ dc.Unlock()
if err != nil {
return nil, err
}
- defer sti.Close()
-
- return resultFromStatement(sti, args...)
+ defer withLock(dc, func() { si.Close() })
+ return resultFromStatement(driverStmt{dc, si}, args...)
}
// Query executes a query that returns rows, typically a SELECT.
@@ -521,24 +717,25 @@ func (db *DB) query(query string, args []interface{}) (*Rows, error) {
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
- if queryer, ok := ci.(driver.Queryer); ok {
+func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+ if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
releaseConn(err)
return nil, err
}
+ dc.Lock()
rowsi, err := queryer.Query(query, dargs)
+ dc.Unlock()
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
return nil, err
}
- // Note: ownership of ci passes to the *Rows, to be freed
+ // Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
- db: db,
- ci: ci,
+ dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
}
@@ -546,27 +743,31 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a
}
}
- sti, err := ci.Prepare(query)
+ dc.Lock()
+ si, err := dc.ci.Prepare(query)
+ dc.Unlock()
if err != nil {
releaseConn(err)
return nil, err
}
- rowsi, err := rowsiFromStatement(sti, args...)
+ ds := driverStmt{dc, si}
+ rowsi, err := rowsiFromStatement(ds, args...)
if err != nil {
releaseConn(err)
- sti.Close()
+ dc.Lock()
+ si.Close()
+ dc.Unlock()
return nil, err
}
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
- db: db,
- ci: ci,
+ dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
- closeStmt: sti,
+ closeStmt: si,
}
return rows, nil
}
@@ -594,18 +795,20 @@ func (db *DB) Begin() (*Tx, error) {
}
func (db *DB) begin() (tx *Tx, err error) {
- ci, err := db.conn()
+ dc, err := db.conn()
if err != nil {
return nil, err
}
- txi, err := ci.Begin()
+ dc.Lock()
+ txi, err := dc.ci.Begin()
+ dc.Unlock()
if err != nil {
- db.putConn(ci, err)
+ db.putConn(dc, err)
return nil, err
}
return &Tx{
db: db,
- ci: ci,
+ dc: dc,
txi: txi,
}, nil
}
@@ -624,15 +827,11 @@ func (db *DB) Driver() driver.Driver {
type Tx struct {
db *DB
- // ci is owned exclusively until Commit or Rollback, at which point
+ // dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn.
- ci driver.Conn
+ dc *driverConn
txi driver.Tx
- // cimu is held while somebody is using ci (between grabConn
- // and releaseConn)
- cimu sync.Mutex
-
// done transitions from false to true exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
@@ -646,21 +845,16 @@ func (tx *Tx) close() {
panic("double close") // internal error
}
tx.done = true
- tx.db.putConn(tx.ci, nil)
- tx.ci = nil
+ tx.db.putConn(tx.dc, nil)
+ tx.dc = nil
tx.txi = nil
}
-func (tx *Tx) grabConn() (driver.Conn, error) {
+func (tx *Tx) grabConn() (*driverConn, error) {
if tx.done {
return nil, ErrTxDone
}
- tx.cimu.Lock()
- return tx.ci, nil
-}
-
-func (tx *Tx) releaseConn() {
- tx.cimu.Unlock()
+ return tx.dc, nil
}
// Commit commits the transaction.
@@ -669,6 +863,8 @@ func (tx *Tx) Commit() error {
return ErrTxDone
}
defer tx.close()
+ tx.dc.Lock()
+ defer tx.dc.Unlock()
return tx.txi.Commit()
}
@@ -678,6 +874,8 @@ func (tx *Tx) Rollback() error {
return ErrTxDone
}
defer tx.close()
+ tx.dc.Lock()
+ defer tx.dc.Unlock()
return tx.txi.Rollback()
}
@@ -701,21 +899,25 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
- ci, err := tx.grabConn()
+ dc, err := tx.grabConn()
if err != nil {
return nil, err
}
- defer tx.releaseConn()
- si, err := ci.Prepare(query)
+ dc.Lock()
+ si, err := dc.ci.Prepare(query)
+ dc.Unlock()
if err != nil {
return nil, err
}
stmt := &Stmt{
- db: tx.db,
- tx: tx,
- txsi: si,
+ db: tx.db,
+ tx: tx,
+ txsi: &driverStmt{
+ Locker: dc,
+ si: si,
+ },
query: query,
}
return stmt, nil
@@ -739,16 +941,20 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- ci, err := tx.grabConn()
+ dc, err := tx.grabConn()
if err != nil {
return &Stmt{stickyErr: err}
}
- defer tx.releaseConn()
- si, err := ci.Prepare(stmt.query)
+ dc.Lock()
+ si, err := dc.ci.Prepare(stmt.query)
+ dc.Unlock()
return &Stmt{
- db: tx.db,
- tx: tx,
- txsi: si,
+ db: tx.db,
+ tx: tx,
+ txsi: &driverStmt{
+ Locker: dc,
+ si: si,
+ },
query: stmt.query,
stickyErr: err,
}
@@ -757,45 +963,46 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
- ci, err := tx.grabConn()
+ dc, err := tx.grabConn()
if err != nil {
return nil, err
}
- defer tx.releaseConn()
- if execer, ok := ci.(driver.Execer); ok {
+ if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}
+ dc.Lock()
resi, err := execer.Exec(query, dargs)
+ dc.Unlock()
if err == nil {
- return result{resi}, nil
+ return driverResult{dc, resi}, nil
}
if err != driver.ErrSkip {
return nil, err
}
}
- sti, err := ci.Prepare(query)
+ dc.Lock()
+ si, err := dc.ci.Prepare(query)
+ dc.Unlock()
if err != nil {
return nil, err
}
- defer sti.Close()
+ defer withLock(dc, func() { si.Close() })
- return resultFromStatement(sti, args...)
+ return resultFromStatement(driverStmt{dc, si}, args...)
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
- ci, err := tx.grabConn()
+ dc, err := tx.grabConn()
if err != nil {
return nil, err
}
-
- releaseConn := func(err error) { tx.releaseConn() }
-
- return tx.db.queryConn(ci, releaseConn, query, args)
+ releaseConn := func(error) {}
+ return tx.db.queryConn(dc, releaseConn, query, args)
}
// QueryRow executes a query that is expected to return at most one row.
@@ -808,7 +1015,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
- ci driver.Conn
+ dc *driverConn
si driver.Stmt
}
@@ -823,7 +1030,7 @@ type Stmt struct {
// If in a transaction, else both nil:
tx *Tx
- txsi driver.Stmt
+ txsi *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
@@ -840,39 +1047,45 @@ type Stmt struct {
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
- _, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt()
if err != nil {
return nil, err
}
defer releaseConn(nil)
- return resultFromStatement(si, args...)
+ return resultFromStatement(driverStmt{dc, si}, args...)
}
-func resultFromStatement(si driver.Stmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+ ds.Lock()
+ want := ds.si.NumInput()
+ ds.Unlock()
+
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
- if want := si.NumInput(); want != -1 && len(args) != want {
+ if want != -1 && len(args) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
- dargs, err := driverArgs(si, args)
+ dargs, err := driverArgs(&ds, args)
if err != nil {
return nil, err
}
- resi, err := si.Exec(dargs)
+ ds.Lock()
+ resi, err := ds.si.Exec(dargs)
+ ds.Unlock()
if err != nil {
return nil, err
}
- return result{resi}, nil
+ return driverResult{ds.Locker, resi}, nil
}
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -891,19 +1104,27 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
if err != nil {
return
}
- releaseConn = func(error) { s.tx.releaseConn() }
- return ci, releaseConn, s.txsi, nil
+ releaseConn = func(error) {}
+ return ci, releaseConn, s.txsi.si, nil
}
var cs connStmt
match := false
- for _, v := range s.css {
- // TODO(bradfitz): lazily clean up entries in this
- // list with dead conns while enumerating
- if _, match = s.db.connIfFree(v.ci); match {
+ for i := 0; i < len(s.css); i++ {
+ v := s.css[i]
+ _, err := s.db.connIfFree(v.dc)
+ if err == nil {
+ match = true
cs = v
break
}
+ if err == errConnClosed {
+ // Lazily remove dead conn from our freelist.
+ s.css[i] = s.css[len(s.css)-1]
+ s.css = s.css[:len(s.css)-1]
+ i--
+ }
+
}
s.mu.Unlock()
@@ -911,11 +1132,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
// TODO(bradfitz): or wait for one? make configurable later?
if !match {
for i := 0; ; i++ {
- ci, err := s.db.conn()
+ dc, err := s.db.conn()
if err != nil {
return nil, nil, nil, err
}
- si, err := ci.Prepare(s.query)
+ dc.Lock()
+ si, err := dc.prepareLocked(s.query)
+ dc.Unlock()
if err == driver.ErrBadConn && i < 10 {
continue
}
@@ -923,14 +1146,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
return nil, nil, nil, err
}
s.mu.Lock()
- cs = connStmt{ci, si}
+ cs = connStmt{dc, si}
s.css = append(s.css, cs)
s.mu.Unlock()
break
}
}
- conn := cs.ci
+ conn := cs.dc
releaseConn = func(err error) { s.db.putConn(conn, err) }
return conn, releaseConn, cs.si, nil
}
@@ -941,12 +1164,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
- ci, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt()
if err != nil {
return nil, err
}
- rowsi, err := rowsiFromStatement(si, args...)
+ ds := driverStmt{dc, si}
+ rowsi, err := rowsiFromStatement(ds, args...)
if err != nil {
releaseConn(err)
return nil, err
@@ -955,8 +1179,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
- db: s.db,
- ci: ci,
+ dc: dc,
rowsi: rowsi,
// releaseConn set below
}
@@ -968,20 +1191,26 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return rows, nil
}
-func rowsiFromStatement(si driver.Stmt, args ...interface{}) (driver.Rows, error) {
+func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
+ ds.Lock()
+ want := ds.si.NumInput()
+ ds.Unlock()
+
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
- if want := si.NumInput(); want != -1 && len(args) != want {
- return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args))
+ if want != -1 && len(args) != want {
+ return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(si, args)
+ dargs, err := driverArgs(&ds, args)
if err != nil {
return nil, err
}
- rowsi, err := si.Query(dargs)
+ ds.Lock()
+ rowsi, err := ds.si.Query(dargs)
+ ds.Unlock()
if err != nil {
return nil, err
}
@@ -1032,7 +1261,9 @@ func (s *Stmt) Close() error {
func (s *Stmt) finalClose() error {
for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.ci, v.si)
+ s.db.noteUnusedDriverStatement(v.dc, v.si)
+ v.dc.removeOpenStmt(v.si)
+ s.db.removeDep(v.dc, s)
}
s.css = nil
return nil
@@ -1052,8 +1283,7 @@ func (s *Stmt) finalClose() error {
// err = rows.Err() // get any error encountered during iteration
// ...
type Rows struct {
- db *DB
- ci driver.Conn // owned; must call releaseConn when closed to release
+ dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error)
rowsi driver.Rows
@@ -1136,24 +1366,6 @@ func (rs *Rows) Scan(dest ...interface{}) error {
return fmt.Errorf("sql: Scan error on column index %d: %v", i, err)
}
}
- for _, dp := range dest {
- b, ok := dp.(*[]byte)
- if !ok {
- continue
- }
- if *b == nil {
- // If the []byte is now nil (for a NULL value),
- // don't fall through to below which would
- // turn it into a non-nil 0-length byte slice
- continue
- }
- if _, ok = dp.(*RawBytes); ok {
- continue
- }
- clone := make([]byte, len(*b))
- copy(clone, *b)
- *b = clone
- }
return nil
}
@@ -1166,10 +1378,10 @@ func (rs *Rows) Close() error {
}
rs.closed = true
err := rs.rowsi.Close()
- rs.releaseConn(err)
if rs.closeStmt != nil {
rs.closeStmt.Close()
}
+ rs.releaseConn(err)
return err
}
@@ -1226,11 +1438,31 @@ type Result interface {
RowsAffected() (int64, error)
}
-type result struct {
- driver.Result
+type driverResult struct {
+ sync.Locker // the *driverConn
+ resi driver.Result
+}
+
+func (dr driverResult) LastInsertId() (int64, error) {
+ dr.Lock()
+ defer dr.Unlock()
+ return dr.resi.LastInsertId()
+}
+
+func (dr driverResult) RowsAffected() (int64, error) {
+ dr.Lock()
+ defer dr.Unlock()
+ return dr.resi.RowsAffected()
}
func stack() string {
- var buf [1024]byte
+ var buf [2 << 10]byte
return string(buf[:runtime.Stack(buf[:], false)])
}
+
+// withLock runs while holding lk.
+func withLock(lk sync.Locker, fn func()) {
+ lk.Lock()
+ fn()
+ lk.Unlock()
+}
diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go
index 53b229600..e6cc667fa 100644
--- a/src/pkg/database/sql/sql_test.go
+++ b/src/pkg/database/sql/sql_test.go
@@ -5,10 +5,11 @@
package sql
import (
- "database/sql/driver"
"fmt"
"reflect"
+ "runtime"
"strings"
+ "sync"
"testing"
"time"
)
@@ -16,10 +17,10 @@ import (
func init() {
type dbConn struct {
db *DB
- c driver.Conn
+ c *driverConn
}
freedFrom := make(map[dbConn]string)
- putConnHook = func(db *DB, c driver.Conn) {
+ putConnHook = func(db *DB, c *driverConn) {
for _, oc := range db.freeConn {
if oc == c {
// print before panic, as panic may get lost due to conflicting panic
@@ -37,7 +38,15 @@ const fakeDBName = "foo"
var chrisBirthday = time.Unix(123456789, 0)
-func newTestDB(t *testing.T, name string) *DB {
+type testOrBench interface {
+ Fatalf(string, ...interface{})
+ Errorf(string, ...interface{})
+ Fatal(...interface{})
+ Error(...interface{})
+ Logf(string, ...interface{})
+}
+
+func newTestDB(t testOrBench, name string) *DB {
db, err := Open("test", fakeDBName)
if err != nil {
t.Fatalf("Open: %v", err)
@@ -51,21 +60,42 @@ func newTestDB(t *testing.T, name string) *DB {
exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
}
+ if name == "magicquery" {
+ // Magic table name and column, known by fakedb_test.go.
+ exec(t, db, "CREATE|magicquery|op=string,millis=int32")
+ exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
+ }
return db
}
-func exec(t *testing.T, db *DB, query string, args ...interface{}) {
+func exec(t testOrBench, db *DB, query string, args ...interface{}) {
_, err := db.Exec(query, args...)
if err != nil {
t.Fatalf("Exec of %q: %v", query, err)
}
}
-func closeDB(t *testing.T, db *DB) {
+func closeDB(t testOrBench, db *DB) {
if e := recover(); e != nil {
fmt.Printf("Panic: %v\n", e)
panic(e)
}
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+ for i, dc := range db.freeConn {
+ if n := len(dc.openStmt); n > 0 {
+ // Just a sanity check. This is legal in
+ // general, but if we make the tests clean up
+ // their statements first, then we can safely
+ // verify this is always zero here, and any
+ // other value is a leak.
+ t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
+ }
+ }
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
@@ -78,7 +108,52 @@ func numPrepares(t *testing.T, db *DB) int {
if n := len(db.freeConn); n != 1 {
t.Fatalf("free conns = %d; want 1", n)
}
- return db.freeConn[0].(*fakeConn).numPrepare
+ return db.freeConn[0].ci.(*fakeConn).numPrepare
+}
+
+func (db *DB) numDeps() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return len(db.dep)
+}
+
+// Dependencies are closed via a goroutine, so this polls waiting for
+// numDeps to fall to want, waiting up to d.
+func (db *DB) numDepsPollUntil(want int, d time.Duration) int {
+ deadline := time.Now().Add(d)
+ for {
+ n := db.numDeps()
+ if n <= want || time.Now().After(deadline) {
+ return n
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+}
+
+func (db *DB) numFreeConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return len(db.freeConn)
+}
+
+func (db *DB) dumpDeps(t *testing.T) {
+ for fc := range db.dep {
+ db.dumpDep(t, 0, fc, map[finalCloser]bool{})
+ }
+}
+
+func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) {
+ seen[dep] = true
+ indent := strings.Repeat(" ", depth)
+ ds := db.dep[dep]
+ for k := range ds {
+ t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k)
+ if fc, ok := k.(finalCloser); ok {
+ if !seen[fc] {
+ db.dumpDep(t, depth+1, fc, seen)
+ }
+ }
+ }
}
func TestQuery(t *testing.T) {
@@ -117,7 +192,7 @@ func TestQuery(t *testing.T) {
// And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection.
- if n := len(db.freeConn); n != 1 {
+ if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
@@ -576,7 +651,7 @@ func TestQueryRowClosingStmt(t *testing.T) {
if len(db.freeConn) != 1 {
t.Fatalf("expected 1 free conn")
}
- fakeConn := db.freeConn[0].(*fakeConn)
+ fakeConn := db.freeConn[0].ci.(*fakeConn)
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
}
@@ -708,3 +783,326 @@ func TestQueryRowNilScanDest(t *testing.T) {
t.Errorf("error = %q; want %q", err.Error(), want)
}
}
+
+func TestIssue4902(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ driver := db.driver.(*fakeDriver)
+ opens0 := driver.openCount
+
+ var stmt *Stmt
+ var err error
+ for i := 0; i < 10; i++ {
+ stmt, err = db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ opens := driver.openCount - opens0
+ if opens > 1 {
+ t.Errorf("opens = %d; want <= 1", opens)
+ t.Logf("db = %#v", db)
+ t.Logf("driver = %#v", driver)
+ t.Logf("stmt = %#v", stmt)
+ }
+}
+
+// Issue 3857
+// This used to deadlock.
+func TestSimultaneousQueries(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ r1, err := tx.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r1.Close()
+
+ r2, err := tx.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r2.Close()
+}
+
+func TestMaxIdleConns(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+ if got := len(db.freeConn); got != 1 {
+ t.Errorf("freeConns = %d; want 1", got)
+ }
+
+ db.SetMaxIdleConns(0)
+
+ if got := len(db.freeConn); got != 0 {
+ t.Errorf("freeConns after set to zero = %d; want 0", got)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+ if got := len(db.freeConn); got != 0 {
+ t.Errorf("freeConns = %d; want 0", got)
+ }
+}
+
+// golang.org/issue/5323
+func TestStmtCloseDeps(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ driver := db.driver.(*fakeDriver)
+
+ driver.mu.Lock()
+ opens0 := driver.openCount
+ closes0 := driver.closeCount
+ driver.mu.Unlock()
+ openDelta0 := opens0 - closes0
+
+ stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Start 50 parallel slow queries.
+ const (
+ nquery = 50
+ sleepMillis = 25
+ nbatch = 2
+ )
+ var wg sync.WaitGroup
+ for batch := 0; batch < nbatch; batch++ {
+ for i := 0; i < nquery; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var op string
+ if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
+ t.Error(err)
+ }
+ }()
+ }
+ // Sleep for twice the expected length of time for the
+ // batch of 50 queries above to finish before starting
+ // the next round.
+ time.Sleep(2 * sleepMillis * time.Millisecond)
+ }
+ wg.Wait()
+
+ if g, w := db.numFreeConns(), 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(4, time.Second); n > 4 {
+ t.Errorf("number of dependencies = %d; expected <= 4", n)
+ db.dumpDeps(t)
+ }
+
+ driver.mu.Lock()
+ opens := driver.openCount - opens0
+ closes := driver.closeCount - closes0
+ driver.mu.Unlock()
+ openDelta := (driver.openCount - driver.closeCount) - openDelta0
+
+ if openDelta > 2 {
+ t.Logf("open calls = %d", opens)
+ t.Logf("close calls = %d", closes)
+ t.Logf("open delta = %d", openDelta)
+ t.Errorf("db connections opened = %d; want <= 2", openDelta)
+ db.dumpDeps(t)
+ }
+
+ if len(stmt.css) > nquery {
+ t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
+ }
+
+ if err := stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, w := db.numFreeConns(), 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(2, time.Second); n > 2 {
+ t.Errorf("number of dependencies = %d; expected <= 2", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxIdleConns(0)
+
+ if g, w := db.numFreeConns(), 0; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(0, time.Second); n > 0 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+}
+
+// golang.org/issue/5046
+func TestCloseConnBeforeStmts(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v; from %s", err, stack())
+ db.dumpDeps(t)
+ t.Errorf("DB = %#v", db)
+ }
+ })
+
+ stmt, err := db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(db.freeConn) != 1 {
+ t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+ }
+ dc := db.freeConn[0]
+ if dc.closed {
+ t.Errorf("conn shouldn't be closed")
+ }
+
+ if n := len(dc.openStmt); n != 1 {
+ t.Errorf("driverConn num openStmt = %d; want 1", n)
+ }
+ err = db.Close()
+ if err != nil {
+ t.Errorf("db Close = %v", err)
+ }
+ if !dc.closed {
+ t.Errorf("after db.Close, driverConn should be closed")
+ }
+ if n := len(dc.openStmt); n != 0 {
+ t.Errorf("driverConn num openStmt = %d; want 0", n)
+ }
+
+ err = stmt.Close()
+ if err != nil {
+ t.Errorf("Stmt close = %v", err)
+ }
+
+ if !dc.closed {
+ t.Errorf("conn should be closed")
+ }
+ if dc.ci != nil {
+ t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
+ }
+}
+
+// golang.org/issue/5283: don't release the Rows' connection in Close
+// before calling Stmt.Close.
+func TestRowsCloseOrder(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxIdleConns(0)
+ setStrictFakeConnClose(t)
+ defer setStrictFakeConnClose(nil)
+
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rows.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func manyConcurrentQueries(t testOrBench) {
+ maxProcs, numReqs := 16, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ stmt, err := db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ reqs := make(chan bool)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for _ = range reqs {
+ rows, err := stmt.Query()
+ if err != nil {
+ t.Errorf("error on query: %v", err)
+ wg.Done()
+ continue
+ }
+
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+
+ wg.Done()
+ }
+ }()
+ }
+
+ for i := 0; i < numReqs; i++ {
+ reqs <- true
+ }
+
+ wg.Wait()
+}
+
+func TestConcurrency(t *testing.T) {
+ manyConcurrentQueries(t)
+}
+
+func BenchmarkConcurrency(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ manyConcurrentQueries(b)
+ }
+}