diff options
Diffstat (limited to 'src/pkg/database')
-rw-r--r-- | src/pkg/database/sql/convert.go | 62 | ||||
-rw-r--r-- | src/pkg/database/sql/convert_test.go | 33 | ||||
-rw-r--r-- | src/pkg/database/sql/driver/driver.go | 4 | ||||
-rw-r--r-- | src/pkg/database/sql/fakedb_test.go | 61 | ||||
-rw-r--r-- | src/pkg/database/sql/sql.go | 624 | ||||
-rw-r--r-- | src/pkg/database/sql/sql_test.go | 416 |
6 files changed, 983 insertions, 217 deletions
diff --git a/src/pkg/database/sql/convert.go b/src/pkg/database/sql/convert.go index 853a7826c..c04adde1f 100644 --- a/src/pkg/database/sql/convert.go +++ b/src/pkg/database/sql/convert.go @@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript // driverArgs converts arguments from callers of Stmt.Exec and // Stmt.Query into driver Values. // -// The statement si may be nil, if no statement is available. -func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { +// The statement ds may be nil, if no statement is available. +func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { dargs := make([]driver.Value, len(args)) + var si driver.Stmt + if ds != nil { + si = ds.si + } cc, ok := si.(driver.ColumnConverter) // Normal path, for a driver.Stmt that is not a ColumnConverter. @@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { // column before going across the network to get the // same error. var err error + ds.Lock() dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) + ds.Unlock() if err != nil { return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) } @@ -106,25 +112,41 @@ func convertAssign(dest, src interface{}) error { if d == nil { return errNilPtr } - bcopy := make([]byte, len(s)) - copy(bcopy, s) - *d = bcopy + *d = cloneBytes(s) return nil case *[]byte: if d == nil { return errNilPtr } + *d = cloneBytes(s) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } *d = s return nil } case nil: switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil case *[]byte: if d == nil { return errNilPtr } *d = nil return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil } } @@ -141,6 +163,26 @@ func convertAssign(dest, src interface{}) error { *d = fmt.Sprintf("%v", src) return nil } + case *[]byte: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = []byte(fmt.Sprintf("%v", src)) + return nil + } + case *RawBytes: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = RawBytes(fmt.Sprintf("%v", src)) + return nil + } case *bool: bv, err := driver.Bool.ConvertValue(src) if err == nil { @@ -212,6 +254,16 @@ func convertAssign(dest, src interface{}) error { return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) } +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + func asString(src interface{}) string { switch v := src.(type) { case string: diff --git a/src/pkg/database/sql/convert_test.go b/src/pkg/database/sql/convert_test.go index 9c362d732..950e24fc3 100644 --- a/src/pkg/database/sql/convert_test.go +++ b/src/pkg/database/sql/convert_test.go @@ -22,6 +22,8 @@ type conversionTest struct { wantint int64 wantuint uint64 wantstr string + wantbytes []byte + wantraw RawBytes wantf32 float32 wantf64 float64 wanttime time.Time @@ -35,6 +37,8 @@ type conversionTest struct { // Target variables for scanning into. var ( scanstr string + scanbytes []byte + scanraw RawBytes scanint int scanint8 int8 scanint16 int16 @@ -56,6 +60,7 @@ var conversionTests = []conversionTest{ {s: someTime, d: &scantime, wanttime: someTime}, // To strings + {s: "string", d: &scanstr, wantstr: "string"}, {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, {s: 123, d: &scanstr, wantstr: "123"}, {s: int8(123), d: &scanstr, wantstr: "123"}, @@ -66,6 +71,31 @@ var conversionTests = []conversionTest{ {s: uint64(123), d: &scanstr, wantstr: "123"}, {s: 1.5, d: &scanstr, wantstr: "1.5"}, + // To []byte + {s: nil, d: &scanbytes, wantbytes: nil}, + {s: "string", d: &scanbytes, wantbytes: []byte("string")}, + {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")}, + {s: 123, d: &scanbytes, wantbytes: []byte("123")}, + {s: int8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: int64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")}, + + // To RawBytes + {s: nil, d: &scanraw, wantraw: nil}, + {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")}, + {s: 123, d: &scanraw, wantraw: RawBytes("123")}, + {s: int8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: int64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint16(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")}, + // Strings to integers {s: "255", d: &scanuint8, wantuint: 255}, {s: "256", d: &scanuint8, wanterr: `converting string "256" to a uint8: strconv.ParseUint: parsing "256": value out of range`}, @@ -113,6 +143,7 @@ var conversionTests = []conversionTest{ {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, {s: true, d: &scaniface, wantiface: true}, {s: nil, d: &scaniface}, + {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)}, } func intPtrValue(intptr interface{}) interface{} { @@ -191,7 +222,7 @@ func TestConversions(t *testing.T) { } if srcBytes, ok := ct.s.([]byte); ok { dstBytes := (*ifptr).([]byte) - if &dstBytes[0] == &srcBytes[0] { + if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] { errf("copy into interface{} didn't copy []byte data") } } diff --git a/src/pkg/database/sql/driver/driver.go b/src/pkg/database/sql/driver/driver.go index 2434e419b..d7ca94f78 100644 --- a/src/pkg/database/sql/driver/driver.go +++ b/src/pkg/database/sql/driver/driver.go @@ -10,8 +10,8 @@ package driver import "errors" -// A driver Value is a value that drivers must be able to handle. -// A Value is either nil or an instance of one of these types: +// Value is a value that drivers must be able to handle. +// It is either nil or an instance of one of these types: // // int64 // float64 diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go index 55597f7de..d900e2ceb 100644 --- a/src/pkg/database/sql/fakedb_test.go +++ b/src/pkg/database/sql/fakedb_test.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "testing" "time" ) @@ -34,9 +35,10 @@ var _ = log.Printf // When opening a fakeDriver's database, it starts empty with no // tables. All tables and data are stored in memory only. type fakeDriver struct { - mu sync.Mutex - openCount int - dbs map[string]*fakeDB + mu sync.Mutex // guards 3 following fields + openCount int // conn opens + closeCount int // conn closes + dbs map[string]*fakeDB } type fakeDB struct { @@ -229,7 +231,43 @@ func (c *fakeConn) Begin() (driver.Tx, error) { return c.currTx, nil } -func (c *fakeConn) Close() error { +var hookPostCloseConn struct { + sync.Mutex + fn func(*fakeConn, error) +} + +func setHookpostCloseConn(fn func(*fakeConn, error)) { + hookPostCloseConn.Lock() + defer hookPostCloseConn.Unlock() + hookPostCloseConn.fn = fn +} + +var testStrictClose *testing.T + +// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close +// fails to close. If nil, the check is disabled. +func setStrictFakeConnClose(t *testing.T) { + testStrictClose = t +} + +func (c *fakeConn) Close() (err error) { + drv := fdriver.(*fakeDriver) + defer func() { + if err != nil && testStrictClose != nil { + testStrictClose.Errorf("failed to close a test fakeConn: %v", err) + } + hookPostCloseConn.Lock() + fn := hookPostCloseConn.fn + hookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + if err == nil { + drv.mu.Lock() + drv.closeCount++ + drv.mu.Unlock() + } + }() if c.currTx != nil { return errors.New("can't close fakeConn; in a Transaction") } @@ -424,6 +462,12 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { } func (s *fakeStmt) Close() error { + if s.c == nil { + panic("nil conn in fakeStmt.Close") + } + if s.c.db == nil { + panic("in fakeStmt.Close, conn's db is nil (already closed)") + } if !s.closed { s.c.incrStat(&s.c.stmtsClosed) s.closed = true @@ -515,6 +559,15 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { if !ok { return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) } + + if s.table == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { + if args[0] == "sleep" { + time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + } + } + } + t.mu.Lock() defer t.mu.Unlock() diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index 4faaa11b1..a80782bfe 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -4,6 +4,9 @@ // Package sql provides a generic interface around SQL (or SQL-like) // databases. +// +// The sql package must be used in conjunction with a database driver. +// See http://golang.org/s/sqldrivers for a list of drivers. package sql import ( @@ -177,32 +180,139 @@ var ErrNoRows = errors.New("sql: no rows in result set") // DB is a database handle. It's safe for concurrent use by multiple // goroutines. // -// If the underlying database driver has the concept of a connection -// and per-connection session state, the sql package manages creating -// and freeing connections automatically, including maintaining a free -// pool of idle connections. If observing session state is required, -// either do not share a *DB between multiple concurrent goroutines or -// create and observe all state only within a transaction. Once -// DB.Open is called, the returned Tx is bound to a single isolated -// connection. Once Tx.Commit or Tx.Rollback is called, that -// connection is returned to DB's idle connection pool. +// The sql package creates and frees connections automatically; it +// also maintains a free pool of idle connections. If the database has +// a concept of per-connection state, such state can only be reliably +// observed within a transaction. Once DB.Begin is called, the +// returned Tx is bound to a single connection. Once Commit or +// Rollback is called on the transaction, that transaction's +// connection is returned to DB's idle connection pool. The pool size +// can be controlled with SetMaxIdleConns. type DB struct { driver driver.Driver dsn string - mu sync.Mutex // protects following fields - outConn map[driver.Conn]bool // whether the conn is in use - freeConn []driver.Conn - closed bool - dep map[finalCloser]depSet - onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned - lastPut map[driver.Conn]string // stacktrace of last conn's put; debug only + mu sync.Mutex // protects following fields + freeConn []*driverConn + closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdle int // zero means defaultMaxIdleConns; negative means 0 +} + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + db *DB + + sync.Mutex // guards following + ci driver.Conn + closed bool + finalClosed bool // ci.Close has been called + openStmt map[driver.Stmt]bool + + // guarded by db.mu + inUse bool + onPut []func() // code (with db.mu held) run when conn is next returned + dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree +} + +func (dc *driverConn) removeOpenStmt(si driver.Stmt) { + dc.Lock() + defer dc.Unlock() + delete(dc.openStmt, si) +} + +func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { + si, err := dc.ci.Prepare(query) + if err == nil { + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // TODO(bradfitz): let drivers opt out of caring about + // stmt closes if the conn is about to close anyway? For now + // do the safe thing, in case stmts need to be closed. + // + // TODO(bradfitz): after Go 1.1, closing driver.Stmts + // should be moved to driverStmt, using unique + // *driverStmts everywhere (including from + // *Stmt.connStmt, instead of returning a + // driver.Stmt), using driverStmt as a pointer + // everywhere, and making it a finalCloser. + if dc.openStmt == nil { + dc.openStmt = make(map[driver.Stmt]bool) + } + dc.openStmt[si] = true + } + return si, err +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + return dc.db.removeDepLocked(dc, dc)() +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + + // And now updates that require holding dc.mu.Lock. + dc.db.mu.Lock() + dc.dbmuClosed = true + fn := dc.db.removeDepLocked(dc, dc) + dc.db.mu.Unlock() + return fn() +} + +func (dc *driverConn) finalClose() error { + dc.Lock() + + for si := range dc.openStmt { + si.Close() + } + dc.openStmt = nil + + err := dc.ci.Close() + dc.ci = nil + dc.finalClosed = true + + dc.Unlock() + return err +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt +} + +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + return ds.si.Close() } // depSet is a finalCloser's outstanding dependencies type depSet map[interface{}]bool // set of true bools -// The finalCloser interface is used by (*DB).addDep and (*DB).get +// The finalCloser interface is used by (*DB).addDep and related +// dependency reference counting. type finalCloser interface { // finalClose is called when the reference count of an object // goes to zero. (*DB).mu is not held while calling it. @@ -215,6 +325,10 @@ func (db *DB) addDep(x finalCloser, dep interface{}) { //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep)) db.mu.Lock() defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep interface{}) { if db.dep == nil { db.dep = make(map[finalCloser]depSet) } @@ -231,10 +345,16 @@ func (db *DB) addDep(x finalCloser, dep interface{}) { // If x no longer has any dependencies, its finalClose method will be // called and its error value will be returned. func (db *DB) removeDep(x finalCloser, dep interface{}) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep)) done := false - db.mu.Lock() xdep := db.dep[x] if xdep != nil { delete(xdep, dep) @@ -243,13 +363,14 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error { done = true } } - db.mu.Unlock() if !done { - return nil + return func() error { return nil } + } + return func() error { + //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x)) + return x.finalClose() } - //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x)) - return x.finalClose() } // Open opens a database specified by its database driver name and a @@ -257,31 +378,47 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error { // database name and connection information. // // Most users will open a database via a driver-specific connection -// helper function that returns a *DB. +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See http://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// Open may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. func Open(driverName, dataSourceName string) (*DB, error) { driveri, ok := drivers[driverName] if !ok { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } - // TODO: optionally proactively connect to a Conn to check - // the dataSourceName: golang.org/issue/4804 db := &DB{ - driver: driveri, - dsn: dataSourceName, - outConn: make(map[driver.Conn]bool), - lastPut: make(map[driver.Conn]string), - onConnPut: make(map[driver.Conn][]func()), + driver: driveri, + dsn: dataSourceName, + lastPut: make(map[*driverConn]string), } return db, nil } +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + // TODO(bradfitz): give drivers an optional hook to implement + // this in a more efficient or more reliable way, if they + // have one. + dc, err := db.conn() + if err != nil { + return err + } + db.putConn(dc, nil) + return nil +} + // Close closes the database, releasing any open resources. func (db *DB) Close() error { db.mu.Lock() defer db.mu.Unlock() var err error - for _, c := range db.freeConn { - err1 := c.Close() + for _, dc := range db.freeConn { + err1 := dc.closeDBLocked() if err1 != nil { err = err1 } @@ -291,15 +428,45 @@ func (db *DB) Close() error { return err } -func (db *DB) maxIdleConns() int { - const defaultMaxIdleConns = 2 - // TODO(bradfitz): ask driver, if supported, for its default preference - // TODO(bradfitz): let users override? - return defaultMaxIdleConns +const defaultMaxIdleConns = 2 + +func (db *DB) maxIdleConnsLocked() int { + n := db.maxIdle + switch { + case n == 0: + // TODO(bradfitz): ask driver, if supported, for its default preference + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } } -// conn returns a newly-opened or cached driver.Conn -func (db *DB) conn() (driver.Conn, error) { +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If n <= 0, no idle connections are retained. +func (db *DB) SetMaxIdleConns(n int) { + db.mu.Lock() + defer db.mu.Unlock() + if n > 0 { + db.maxIdle = n + } else { + // No idle connections. + db.maxIdle = -1 + } + for len(db.freeConn) > 0 && len(db.freeConn) > n { + nfree := len(db.freeConn) + dc := db.freeConn[nfree-1] + db.freeConn[nfree-1] = nil + db.freeConn = db.freeConn[:nfree-1] + go dc.Close() + } +} + +// conn returns a newly-opened or cached *driverConn +func (db *DB) conn() (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() @@ -308,30 +475,47 @@ func (db *DB) conn() (driver.Conn, error) { if n := len(db.freeConn); n > 0 { conn := db.freeConn[n-1] db.freeConn = db.freeConn[:n-1] - db.outConn[conn] = true + conn.inUse = true db.mu.Unlock() return conn, nil } db.mu.Unlock() - conn, err := db.driver.Open(db.dsn) - if err == nil { - db.mu.Lock() - db.outConn[conn] = true - db.mu.Unlock() + + ci, err := db.driver.Open(db.dsn) + if err != nil { + return nil, err + } + dc := &driverConn{ + db: db, + ci: ci, } - return conn, err + db.mu.Lock() + db.addDepLocked(dc, dc) + dc.inUse = true + db.mu.Unlock() + return dc, nil } -// connIfFree returns (wanted, true) if wanted is still a valid conn and +var ( + errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed") + errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy") +) + +// connIfFree returns (wanted, nil) if wanted is still a valid conn and // isn't in use. // -// If wanted is valid but in use, connIfFree returns (wanted, false). -// If wanted is invalid, connIfFre returns (nil, false). -func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { +// The error is errConnClosed if the connection if the requested connection +// is invalid because it's been closed. +// +// The error is errConnBusy if the connection is in use. +func (db *DB) connIfFree(wanted *driverConn) (*driverConn, error) { db.mu.Lock() defer db.mu.Unlock() - if db.outConn[wanted] { - return conn, false + if wanted.inUse { + return nil, errConnBusy + } + if wanted.dbmuClosed { + return nil, errConnClosed } for i, conn := range db.freeConn { if conn != wanted { @@ -339,27 +523,36 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { } db.freeConn[i] = db.freeConn[len(db.freeConn)-1] db.freeConn = db.freeConn[:len(db.freeConn)-1] - db.outConn[wanted] = true - return wanted, true + wanted.inUse = true + return wanted, nil } - return nil, false + // TODO(bradfitz): shouldn't get here. After Go 1.1, change this to: + // panic("connIfFree call requested a non-closed, non-busy, non-free conn") + // Which passes all the tests, but I'm too paranoid to include this + // late in Go 1.1. + // Instead, treat it like a busy connection: + return nil, errConnBusy } // putConnHook is a hook for testing. -var putConnHook func(*DB, driver.Conn) +var putConnHook func(*DB, *driverConn) // noteUnusedDriverStatement notes that si is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { db.mu.Lock() defer db.mu.Unlock() - if db.outConn[c] { - db.onConnPut[c] = append(db.onConnPut[c], func() { + if c.inUse { + c.onPut = append(c.onPut, func() { si.Close() }) } else { - si.Close() + c.Lock() + defer c.Unlock() + if !c.finalClosed { + si.Close() + } } } @@ -369,25 +562,23 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(c driver.Conn, err error) { +func (db *DB) putConn(dc *driverConn, err error) { db.mu.Lock() - if !db.outConn[c] { + if !dc.inUse { if debugGetPut { - fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c]) + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) } panic("sql: connection returned that was never out") } if debugGetPut { - db.lastPut[c] = stack() + db.lastPut[dc] = stack() } - delete(db.outConn, c) + dc.inUse = false - if fns, ok := db.onConnPut[c]; ok { - for _, fn := range fns { - fn() - } - delete(db.onConnPut, c) + for _, fn := range dc.onPut { + fn() } + dc.onPut = nil if err == driver.ErrBadConn { // Don't reuse bad connections. @@ -395,17 +586,16 @@ func (db *DB) putConn(c driver.Conn, err error) { return } if putConnHook != nil { - putConnHook(db, c) + putConnHook(db, dc) } - if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { - db.freeConn = append(db.freeConn, c) + if n := len(db.freeConn); !db.closed && n < db.maxIdleConnsLocked() { + db.freeConn = append(db.freeConn, dc) db.mu.Unlock() return } - // TODO: check to see if we need this Conn for any prepared - // statements which are still active? db.mu.Unlock() - c.Close() + + dc.Close() } // Prepare creates a prepared statement for later queries or executions. @@ -430,21 +620,24 @@ func (db *DB) prepare(query string) (*Stmt, error) { // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.prepareLocked(query) + dc.Unlock() if err != nil { - db.putConn(ci, err) + db.putConn(dc, err) return nil, err } stmt := &Stmt{ db: db, query: query, - css: []connStmt{{ci, si}}, + css: []connStmt{{dc, si}}, } db.addDep(stmt, stmt) + db.putConn(dc, nil) return stmt, nil } @@ -463,35 +656,38 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { } func (db *DB) exec(query string, args []interface{}) (res Result, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } defer func() { - db.putConn(ci, err) + db.putConn(dc, err) }() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err != driver.ErrSkip { if err != nil { return nil, err } - return result{resi}, nil + return driverResult{dc, resi}, nil } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() - - return resultFromStatement(sti, args...) + defer withLock(dc, func() { si.Close() }) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. @@ -521,24 +717,25 @@ func (db *DB) query(query string, args []interface{}) (*Rows, error) { // queryConn executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { - if queryer, ok := ci.(driver.Queryer); ok { +func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { + if queryer, ok := dc.ci.(driver.Queryer); ok { dargs, err := driverArgs(nil, args) if err != nil { releaseConn(err) return nil, err } + dc.Lock() rowsi, err := queryer.Query(query, dargs) + dc.Unlock() if err != driver.ErrSkip { if err != nil { releaseConn(err) return nil, err } - // Note: ownership of ci passes to the *Rows, to be freed + // Note: ownership of dc passes to the *Rows, to be freed // with releaseConn. rows := &Rows{ - db: db, - ci: ci, + dc: dc, releaseConn: releaseConn, rowsi: rowsi, } @@ -546,27 +743,31 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { releaseConn(err) return nil, err } - rowsi, err := rowsiFromStatement(sti, args...) + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { releaseConn(err) - sti.Close() + dc.Lock() + si.Close() + dc.Unlock() return nil, err } // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. rows := &Rows{ - db: db, - ci: ci, + dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: sti, + closeStmt: si, } return rows, nil } @@ -594,18 +795,20 @@ func (db *DB) Begin() (*Tx, error) { } func (db *DB) begin() (tx *Tx, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - txi, err := ci.Begin() + dc.Lock() + txi, err := dc.ci.Begin() + dc.Unlock() if err != nil { - db.putConn(ci, err) + db.putConn(dc, err) return nil, err } return &Tx{ db: db, - ci: ci, + dc: dc, txi: txi, }, nil } @@ -624,15 +827,11 @@ func (db *DB) Driver() driver.Driver { type Tx struct { db *DB - // ci is owned exclusively until Commit or Rollback, at which point + // dc is owned exclusively until Commit or Rollback, at which point // it's returned with putConn. - ci driver.Conn + dc *driverConn txi driver.Tx - // cimu is held while somebody is using ci (between grabConn - // and releaseConn) - cimu sync.Mutex - // done transitions from false to true exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. @@ -646,21 +845,16 @@ func (tx *Tx) close() { panic("double close") // internal error } tx.done = true - tx.db.putConn(tx.ci, nil) - tx.ci = nil + tx.db.putConn(tx.dc, nil) + tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (driver.Conn, error) { +func (tx *Tx) grabConn() (*driverConn, error) { if tx.done { return nil, ErrTxDone } - tx.cimu.Lock() - return tx.ci, nil -} - -func (tx *Tx) releaseConn() { - tx.cimu.Unlock() + return tx.dc, nil } // Commit commits the transaction. @@ -669,6 +863,8 @@ func (tx *Tx) Commit() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Commit() } @@ -678,6 +874,8 @@ func (tx *Tx) Rollback() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Rollback() } @@ -701,21 +899,25 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } - defer tx.releaseConn() - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } stmt := &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: query, } return stmt, nil @@ -739,16 +941,20 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return &Stmt{stickyErr: err} } - defer tx.releaseConn() - si, err := ci.Prepare(stmt.query) + dc.Lock() + si, err := dc.ci.Prepare(stmt.query) + dc.Unlock() return &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: stmt.query, stickyErr: err, } @@ -757,45 +963,46 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } - defer tx.releaseConn() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err == nil { - return result{resi}, nil + return driverResult{dc, resi}, nil } if err != driver.ErrSkip { return nil, err } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() + defer withLock(dc, func() { si.Close() }) - return resultFromStatement(sti, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } - - releaseConn := func(err error) { tx.releaseConn() } - - return tx.db.queryConn(ci, releaseConn, query, args) + releaseConn := func(error) {} + return tx.db.queryConn(dc, releaseConn, query, args) } // QueryRow executes a query that is expected to return at most one row. @@ -808,7 +1015,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { // connStmt is a prepared statement on a particular connection. type connStmt struct { - ci driver.Conn + dc *driverConn si driver.Stmt } @@ -823,7 +1030,7 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi driver.Stmt + txsi *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -840,39 +1047,45 @@ type Stmt struct { func (s *Stmt) Exec(args ...interface{}) (Result, error) { s.closemu.RLock() defer s.closemu.RUnlock() - _, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } defer releaseConn(nil) - return resultFromStatement(si, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } -func resultFromStatement(si driver.Stmt, args ...interface{}) (Result, error) { +func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { + if want != -1 && len(args) != want { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - resi, err := si.Exec(dargs) + ds.Lock() + resi, err := ds.si.Exec(dargs) + ds.Unlock() if err != nil { return nil, err } - return result{resi}, nil + return driverResult{ds.Locker, resi}, nil } // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { if err = s.stickyErr; err != nil { return } @@ -891,19 +1104,27 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St if err != nil { return } - releaseConn = func(error) { s.tx.releaseConn() } - return ci, releaseConn, s.txsi, nil + releaseConn = func(error) {} + return ci, releaseConn, s.txsi.si, nil } var cs connStmt match := false - for _, v := range s.css { - // TODO(bradfitz): lazily clean up entries in this - // list with dead conns while enumerating - if _, match = s.db.connIfFree(v.ci); match { + for i := 0; i < len(s.css); i++ { + v := s.css[i] + _, err := s.db.connIfFree(v.dc) + if err == nil { + match = true cs = v break } + if err == errConnClosed { + // Lazily remove dead conn from our freelist. + s.css[i] = s.css[len(s.css)-1] + s.css = s.css[:len(s.css)-1] + i-- + } + } s.mu.Unlock() @@ -911,11 +1132,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St // TODO(bradfitz): or wait for one? make configurable later? if !match { for i := 0; ; i++ { - ci, err := s.db.conn() + dc, err := s.db.conn() if err != nil { return nil, nil, nil, err } - si, err := ci.Prepare(s.query) + dc.Lock() + si, err := dc.prepareLocked(s.query) + dc.Unlock() if err == driver.ErrBadConn && i < 10 { continue } @@ -923,14 +1146,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St return nil, nil, nil, err } s.mu.Lock() - cs = connStmt{ci, si} + cs = connStmt{dc, si} s.css = append(s.css, cs) s.mu.Unlock() break } } - conn := cs.ci + conn := cs.dc releaseConn = func(err error) { s.db.putConn(conn, err) } return conn, releaseConn, cs.si, nil } @@ -941,12 +1164,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { s.closemu.RLock() defer s.closemu.RUnlock() - ci, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } - rowsi, err := rowsiFromStatement(si, args...) + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { releaseConn(err) return nil, err @@ -955,8 +1179,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. rows := &Rows{ - db: s.db, - ci: ci, + dc: dc, rowsi: rowsi, // releaseConn set below } @@ -968,20 +1191,26 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return rows, nil } -func rowsiFromStatement(si driver.Stmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) + if want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - rowsi, err := si.Query(dargs) + ds.Lock() + rowsi, err := ds.si.Query(dargs) + ds.Unlock() if err != nil { return nil, err } @@ -1032,7 +1261,9 @@ func (s *Stmt) Close() error { func (s *Stmt) finalClose() error { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.ci, v.si) + s.db.noteUnusedDriverStatement(v.dc, v.si) + v.dc.removeOpenStmt(v.si) + s.db.removeDep(v.dc, s) } s.css = nil return nil @@ -1052,8 +1283,7 @@ func (s *Stmt) finalClose() error { // err = rows.Err() // get any error encountered during iteration // ... type Rows struct { - db *DB - ci driver.Conn // owned; must call releaseConn when closed to release + dc *driverConn // owned; must call releaseConn when closed to release releaseConn func(error) rowsi driver.Rows @@ -1136,24 +1366,6 @@ func (rs *Rows) Scan(dest ...interface{}) error { return fmt.Errorf("sql: Scan error on column index %d: %v", i, err) } } - for _, dp := range dest { - b, ok := dp.(*[]byte) - if !ok { - continue - } - if *b == nil { - // If the []byte is now nil (for a NULL value), - // don't fall through to below which would - // turn it into a non-nil 0-length byte slice - continue - } - if _, ok = dp.(*RawBytes); ok { - continue - } - clone := make([]byte, len(*b)) - copy(clone, *b) - *b = clone - } return nil } @@ -1166,10 +1378,10 @@ func (rs *Rows) Close() error { } rs.closed = true err := rs.rowsi.Close() - rs.releaseConn(err) if rs.closeStmt != nil { rs.closeStmt.Close() } + rs.releaseConn(err) return err } @@ -1226,11 +1438,31 @@ type Result interface { RowsAffected() (int64, error) } -type result struct { - driver.Result +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() } func stack() string { - var buf [1024]byte + var buf [2 << 10]byte return string(buf[:runtime.Stack(buf[:], false)]) } + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + fn() + lk.Unlock() +} diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 53b229600..e6cc667fa 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -5,10 +5,11 @@ package sql import ( - "database/sql/driver" "fmt" "reflect" + "runtime" "strings" + "sync" "testing" "time" ) @@ -16,10 +17,10 @@ import ( func init() { type dbConn struct { db *DB - c driver.Conn + c *driverConn } freedFrom := make(map[dbConn]string) - putConnHook = func(db *DB, c driver.Conn) { + putConnHook = func(db *DB, c *driverConn) { for _, oc := range db.freeConn { if oc == c { // print before panic, as panic may get lost due to conflicting panic @@ -37,7 +38,15 @@ const fakeDBName = "foo" var chrisBirthday = time.Unix(123456789, 0) -func newTestDB(t *testing.T, name string) *DB { +type testOrBench interface { + Fatalf(string, ...interface{}) + Errorf(string, ...interface{}) + Fatal(...interface{}) + Error(...interface{}) + Logf(string, ...interface{}) +} + +func newTestDB(t testOrBench, name string) *DB { db, err := Open("test", fakeDBName) if err != nil { t.Fatalf("Open: %v", err) @@ -51,21 +60,42 @@ func newTestDB(t *testing.T, name string) *DB { exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) } + if name == "magicquery" { + // Magic table name and column, known by fakedb_test.go. + exec(t, db, "CREATE|magicquery|op=string,millis=int32") + exec(t, db, "INSERT|magicquery|op=sleep,millis=10") + } return db } -func exec(t *testing.T, db *DB, query string, args ...interface{}) { +func exec(t testOrBench, db *DB, query string, args ...interface{}) { _, err := db.Exec(query, args...) if err != nil { t.Fatalf("Exec of %q: %v", query, err) } } -func closeDB(t *testing.T, db *DB) { +func closeDB(t testOrBench, db *DB) { if e := recover(); e != nil { fmt.Printf("Panic: %v\n", e) panic(e) } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + for i, dc := range db.freeConn { + if n := len(dc.openStmt); n > 0 { + // Just a sanity check. This is legal in + // general, but if we make the tests clean up + // their statements first, then we can safely + // verify this is always zero here, and any + // other value is a leak. + t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n) + } + } err := db.Close() if err != nil { t.Fatalf("error closing DB: %v", err) @@ -78,7 +108,52 @@ func numPrepares(t *testing.T, db *DB) int { if n := len(db.freeConn); n != 1 { t.Fatalf("free conns = %d; want 1", n) } - return db.freeConn[0].(*fakeConn).numPrepare + return db.freeConn[0].ci.(*fakeConn).numPrepare +} + +func (db *DB) numDeps() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.dep) +} + +// Dependencies are closed via a goroutine, so this polls waiting for +// numDeps to fall to want, waiting up to d. +func (db *DB) numDepsPollUntil(want int, d time.Duration) int { + deadline := time.Now().Add(d) + for { + n := db.numDeps() + if n <= want || time.Now().After(deadline) { + return n + } + time.Sleep(50 * time.Millisecond) + } +} + +func (db *DB) numFreeConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.freeConn) +} + +func (db *DB) dumpDeps(t *testing.T) { + for fc := range db.dep { + db.dumpDep(t, 0, fc, map[finalCloser]bool{}) + } +} + +func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) { + seen[dep] = true + indent := strings.Repeat(" ", depth) + ds := db.dep[dep] + for k := range ds { + t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k) + if fc, ok := k.(finalCloser); ok { + if !seen[fc] { + db.dumpDep(t, depth+1, fc, seen) + } + } + } } func TestQuery(t *testing.T) { @@ -117,7 +192,7 @@ func TestQuery(t *testing.T) { // And verify that the final rows.Next() call, which hit EOF, // also closed the rows connection. - if n := len(db.freeConn); n != 1 { + if n := db.numFreeConns(); n != 1 { t.Fatalf("free conns after query hitting EOF = %d; want 1", n) } if prepares := numPrepares(t, db) - prepares0; prepares != 1 { @@ -576,7 +651,7 @@ func TestQueryRowClosingStmt(t *testing.T) { if len(db.freeConn) != 1 { t.Fatalf("expected 1 free conn") } - fakeConn := db.freeConn[0].(*fakeConn) + fakeConn := db.freeConn[0].ci.(*fakeConn) if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { t.Errorf("statement close mismatch: made %d, closed %d", made, closed) } @@ -708,3 +783,326 @@ func TestQueryRowNilScanDest(t *testing.T) { t.Errorf("error = %q; want %q", err.Error(), want) } } + +func TestIssue4902(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + opens0 := driver.openCount + + var stmt *Stmt + var err error + for i := 0; i < 10; i++ { + stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } + + opens := driver.openCount - opens0 + if opens > 1 { + t.Errorf("opens = %d; want <= 1", opens) + t.Logf("db = %#v", db) + t.Logf("driver = %#v", driver) + t.Logf("stmt = %#v", stmt) + } +} + +// Issue 3857 +// This used to deadlock. +func TestSimultaneousQueries(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r1, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r1.Close() + + r2, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r2.Close() +} + +func TestMaxIdleConns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 1 { + t.Errorf("freeConns = %d; want 1", got) + } + + db.SetMaxIdleConns(0) + + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns after set to zero = %d; want 0", got) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns = %d; want 0", got) + } +} + +// golang.org/issue/5323 +func TestStmtCloseDeps(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + openDelta0 := opens0 - closes0 + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Sleep for twice the expected length of time for the + // batch of 50 queries above to finish before starting + // the next round. + time.Sleep(2 * sleepMillis * time.Millisecond) + } + wg.Wait() + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(4, time.Second); n > 4 { + t.Errorf("number of dependencies = %d; expected <= 4", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + driver.mu.Unlock() + openDelta := (driver.openCount - driver.closeCount) - openDelta0 + + if openDelta > 2 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Logf("open delta = %d", openDelta) + t.Errorf("db connections opened = %d; want <= 2", openDelta) + db.dumpDeps(t) + } + + if len(stmt.css) > nquery { + t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(2, time.Second); n > 2 { + t.Errorf("number of dependencies = %d; expected <= 2", n) + db.dumpDeps(t) + } + + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } +} + +// golang.org/issue/5046 +func TestCloseConnBeforeStmts(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v; from %s", err, stack()) + db.dumpDeps(t) + t.Errorf("DB = %#v", db) + } + }) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn)) + } + dc := db.freeConn[0] + if dc.closed { + t.Errorf("conn shouldn't be closed") + } + + if n := len(dc.openStmt); n != 1 { + t.Errorf("driverConn num openStmt = %d; want 1", n) + } + err = db.Close() + if err != nil { + t.Errorf("db Close = %v", err) + } + if !dc.closed { + t.Errorf("after db.Close, driverConn should be closed") + } + if n := len(dc.openStmt); n != 0 { + t.Errorf("driverConn num openStmt = %d; want 0", n) + } + + err = stmt.Close() + if err != nil { + t.Errorf("Stmt close = %v", err) + } + + if !dc.closed { + t.Errorf("conn should be closed") + } + if dc.ci != nil { + t.Errorf("after Stmt Close, driverConn's Conn interface should be nil") + } +} + +// golang.org/issue/5283: don't release the Rows' connection in Close +// before calling Stmt.Close. +func TestRowsCloseOrder(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxIdleConns(0) + setStrictFakeConnClose(t) + defer setStrictFakeConnClose(nil) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + err = rows.Close() + if err != nil { + t.Fatal(err) + } +} + +func manyConcurrentQueries(t testOrBench) { + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for _ = range reqs { + rows, err := stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + wg.Done() + continue + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + +func TestConcurrency(t *testing.T) { + manyConcurrentQueries(t) +} + +func BenchmarkConcurrency(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + manyConcurrentQueries(b) + } +} |