diff options
Diffstat (limited to 'src/database')
-rw-r--r-- | src/database/sql/convert.go | 299 | ||||
-rw-r--r-- | src/database/sql/convert_test.go | 348 | ||||
-rw-r--r-- | src/database/sql/doc.txt | 46 | ||||
-rw-r--r-- | src/database/sql/driver/driver.go | 211 | ||||
-rw-r--r-- | src/database/sql/driver/types.go | 252 | ||||
-rw-r--r-- | src/database/sql/driver/types_test.go | 65 | ||||
-rw-r--r-- | src/database/sql/example_test.go | 46 | ||||
-rw-r--r-- | src/database/sql/fakedb_test.go | 829 | ||||
-rw-r--r-- | src/database/sql/sql.go | 1770 | ||||
-rw-r--r-- | src/database/sql/sql_test.go | 1987 |
10 files changed, 5853 insertions, 0 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go new file mode 100644 index 000000000..c0b38a249 --- /dev/null +++ b/src/database/sql/convert.go @@ -0,0 +1,299 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// driverArgs converts arguments from callers of Stmt.Exec and +// Stmt.Query into driver Values. +// +// The statement ds may be nil, if no statement is available. +func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { + dargs := make([]driver.Value, len(args)) + var si driver.Stmt + if ds != nil { + si = ds.si + } + cc, ok := si.(driver.ColumnConverter) + + // Normal path, for a driver.Stmt that is not a ColumnConverter. + if !ok { + for n, arg := range args { + var err error + dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err) + } + } + return dargs, nil + } + + // Let the Stmt convert its own arguments. + for n, arg := range args { + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if svi, ok := arg.(driver.Valuer); ok { + sv, err := svi.Value() + if err != nil { + return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err) + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv) + } + arg = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + var err error + ds.Lock() + dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) + ds.Unlock() + if err != nil { + return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) + } + if !driver.IsValue(dargs[n]) { + return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T", + arg, dargs[n]) + } + } + + return dargs, nil +} + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if dv.Kind() == sv.Kind() { + dv.Set(sv) + return nil + } + + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + } + + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go new file mode 100644 index 000000000..98af9fb64 --- /dev/null +++ b/src/database/sql/convert_test.go @@ -0,0 +1,348 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "runtime" + "testing" + "time" +) + +var someTime = time.Unix(123, 0) +var answer int64 = 42 + +type conversionTest struct { + s, d interface{} // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantstr string + wantbytes []byte + wantraw RawBytes + wantf32 float32 + wantf64 float64 + wanttime time.Time + wantbool bool // used if d is of type *bool + wanterr string + wantiface interface{} + wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr + wantnil bool // if true, *d must be *int64(nil) +} + +// Target variables for scanning into. +var ( + scanstr string + scanbytes []byte + scanraw RawBytes + scanint int + scanint8 int8 + scanint16 int16 + scanint32 int32 + scanuint8 uint8 + scanuint16 uint16 + scanbool bool + scanf32 float32 + scanf64 float64 + scantime time.Time + scanptr *int64 + scaniface interface{} +) + +var conversionTests = []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, + + // To strings + {s: "string", d: &scanstr, wantstr: "string"}, + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + + // To []byte + {s: nil, d: &scanbytes, wantbytes: nil}, + {s: "string", d: &scanbytes, wantbytes: []byte("string")}, + {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")}, + {s: 123, d: &scanbytes, wantbytes: []byte("123")}, + {s: int8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: int64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")}, + + // To RawBytes + {s: nil, d: &scanraw, wantraw: nil}, + {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")}, + {s: 123, d: &scanraw, wantraw: RawBytes("123")}, + {s: int8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: int64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint16(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: `converting string "256" to a uint8: strconv.ParseUint: parsing "256": value out of range`}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: strconv.ParseInt: parsing "foo": invalid syntax`}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + {s: "True", d: &scanbool, wantbool: true}, + {s: "TRUE", d: &scanbool, wantbool: true}, + {s: "1", d: &scanbool, wantbool: true}, + {s: 1, d: &scanbool, wantbool: true}, + {s: int64(1), d: &scanbool, wantbool: true}, + {s: uint16(1), d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + {s: "false", d: &scanbool, wantbool: false}, + {s: "FALSE", d: &scanbool, wantbool: false}, + {s: "0", d: &scanbool, wantbool: false}, + {s: 0, d: &scanbool, wantbool: false}, + {s: int64(0), d: &scanbool, wantbool: false}, + {s: uint16(0), d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`}, + {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`}, + + // Floats + {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: float64(1.5)}, + + // Pointers + {s: interface{}(nil), d: &scanptr, wantnil: true}, + {s: int64(42), d: &scanptr, wantptr: &answer}, + + // To interface{} + {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)}, + {s: int64(1), d: &scaniface, wantiface: int64(1)}, + {s: "str", d: &scaniface, wantiface: "str"}, + {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, + {s: true, d: &scaniface, wantiface: true}, + {s: nil, d: &scaniface}, + {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)}, +} + +func intPtrValue(intptr interface{}) interface{} { + return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int() +} + +func intValue(intptr interface{}) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr interface{}) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func float64Value(ptr interface{}) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr interface{}) float32 { + return *(ptr.(*float32)) +} + +func timeValue(ptr interface{}) time.Time { + return *(ptr.(*time.Time)) +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests { + err := convertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.Error() + } + errf := func(format string, args ...interface{}) { + base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } + if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) + } + if ct.wantnil && *ct.d.(**int64) != nil { + errf("want nil, got %v", intPtrValue(ct.d)) + } + if ct.wantptr != nil { + if *ct.d.(**int64) == nil { + errf("want pointer to %v, got nil", *ct.wantptr) + } else if *ct.wantptr != intPtrValue(ct.d) { + errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d)) + } + } + if ifptr, ok := ct.d.(*interface{}); ok { + if !reflect.DeepEqual(ct.wantiface, scaniface) { + errf("want interface %#v, got %#v", ct.wantiface, scaniface) + continue + } + if srcBytes, ok := ct.s.([]byte); ok { + dstBytes := (*ifptr).([]byte) + if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] { + errf("copy into interface{} didn't copy []byte data") + } + } + } + } +} + +func TestNullString(t *testing.T) { + var ns NullString + convertAssign(&ns, []byte("foo")) + if !ns.Valid { + t.Errorf("expecting not null") + } + if ns.String != "foo" { + t.Errorf("expecting foo; got %q", ns.String) + } + convertAssign(&ns, nil) + if ns.Valid { + t.Errorf("expecting null on nil") + } + if ns.String != "" { + t.Errorf("expecting blank on nil; got %q", ns.String) + } +} + +type valueConverterTest struct { + c driver.ValueConverter + in, out interface{} + err string +} + +var valueConverterTests = []valueConverterTest{ + {driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""}, + {driver.DefaultParameterConverter, NullString{"", false}, nil, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} + +// Tests that assigning to RawBytes doesn't allocate (and also works). +func TestRawBytesAllocs(t *testing.T) { + var tests = []struct { + name string + in interface{} + want string + }{ + {"uint64", uint64(12345678), "12345678"}, + {"uint32", uint32(1234), "1234"}, + {"uint16", uint16(12), "12"}, + {"uint8", uint8(1), "1"}, + {"uint", uint(123), "123"}, + {"int", int(123), "123"}, + {"int8", int8(1), "1"}, + {"int16", int16(12), "12"}, + {"int32", int32(1234), "1234"}, + {"int64", int64(12345678), "12345678"}, + {"float32", float32(1.5), "1.5"}, + {"float64", float64(64), "64"}, + {"bool", false, "false"}, + } + + buf := make(RawBytes, 10) + test := func(name string, in interface{}, want string) { + if err := convertAssign(&buf, in); err != nil { + t.Fatalf("%s: convertAssign = %v", name, err) + } + match := len(buf) == len(want) + if match { + for i, b := range buf { + if want[i] != b { + match = false + break + } + } + } + if !match { + t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want)) + } + } + + n := testing.AllocsPerRun(100, func() { + for _, tt := range tests { + test(tt.name, tt.in, tt.want) + } + }) + + // The numbers below are only valid for 64-bit interface word sizes, + // and gc. With 32-bit words there are more convT2E allocs, and + // with gccgo, only pointers currently go in interface data. + // So only care on amd64 gc for now. + measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc" + + if n > 0.5 && measureAllocs { + t.Fatalf("allocs = %v; want 0", n) + } + + // This one involves a convT2E allocation, string -> interface{} + n = testing.AllocsPerRun(100, func() { + test("string", "foo", "foo") + }) + if n > 1.5 && measureAllocs { + t.Fatalf("allocs = %v; want max 1", n) + } +} diff --git a/src/database/sql/doc.txt b/src/database/sql/doc.txt new file mode 100644 index 000000000..405c5ed2a --- /dev/null +++ b/src/database/sql/doc.txt @@ -0,0 +1,46 @@ +Goals of the sql and sql/driver packages: + +* Provide a generic database API for a variety of SQL or SQL-like + databases. There currently exist Go libraries for SQLite, MySQL, + and Postgres, but all with a very different feel, and often + a non-Go-like feel. + +* Feel like Go. + +* Care mostly about the common cases. Common SQL should be portable. + SQL edge cases or db-specific extensions can be detected and + conditionally used by the application. It is a non-goal to care + about every particular db's extension or quirk. + +* Separate out the basic implementation of a database driver + (implementing the sql/driver interfaces) vs the implementation + of all the user-level types and convenience methods. + In a nutshell: + + User Code ---> sql package (concrete types) ---> sql/driver (interfaces) + Database Driver -> sql (to register) + sql/driver (implement interfaces) + +* Make type casting/conversions consistent between all drivers. To + achieve this, most of the conversions are done in the sql package, + not in each driver. The drivers then only have to deal with a + smaller set of types. + +* Be flexible with type conversions, but be paranoid about silent + truncation or other loss of precision. + +* Handle concurrency well. Users shouldn't need to care about the + database's per-connection thread safety issues (or lack thereof), + and shouldn't have to maintain their own free pools of connections. + The 'db' package should deal with that bookkeeping as needed. Given + an *sql.DB, it should be possible to share that instance between + multiple goroutines, without any extra synchronization. + +* Push complexity, where necessary, down into the sql+driver packages, + rather than exposing it to users. Said otherwise, the sql package + should expose an ideal database that's not finnicky about how it's + accessed, even if that's not true. + +* Provide optional interfaces in sql/driver for drivers to implement + for special cases or fastpaths. But the only party that knows about + those is the sql package. To user code, some stuff just might start + working or start working slightly faster. diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go new file mode 100644 index 000000000..eca25f29a --- /dev/null +++ b/src/database/sql/driver/driver.go @@ -0,0 +1,211 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package driver defines interfaces to be implemented by database +// drivers as used by package sql. +// +// Most code should use package sql. +package driver + +import "errors" + +// Value is a value that drivers must be able to handle. +// It is either nil or an instance of one of these types: +// +// int64 +// float64 +// bool +// []byte +// string [*] everywhere except from Rows.Next. +// time.Time +type Value interface{} + +// Driver is the interface that must be implemented by a database +// driver. +type Driver interface { + // Open returns a new connection to the database. + // The name is a string in a driver-specific format. + // + // Open may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // + // The returned connection is only used by one goroutine at a + // time. + Open(name string) (Conn, error) +} + +// ErrSkip may be returned by some optional interfaces' methods to +// indicate at runtime that the fast path is unavailable and the sql +// package should continue as if the optional interface was not +// implemented. ErrSkip is only supported where explicitly +// documented. +var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") + +// ErrBadConn should be returned by a driver to signal to the sql +// package that a driver.Conn is in a bad state (such as the server +// having earlier closed the connection) and the sql package should +// retry on a new connection. +// +// To prevent duplicate operations, ErrBadConn should NOT be returned +// if there's a possibility that the database server might have +// performed the operation. Even if the server sends back an error, +// you shouldn't return ErrBadConn. +var ErrBadConn = errors.New("driver: bad connection") + +// Execer is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement Execer, the sql package's DB.Exec will +// first prepare a query, execute the statement, and then close the +// statement. +// +// Exec may return ErrSkip. +type Execer interface { + Exec(query string, args []Value) (Result, error) +} + +// Queryer is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement Queryer, the sql package's DB.Query will +// first prepare a query, execute the statement, and then close the +// statement. +// +// Query may return ErrSkip. +type Queryer interface { + Query(query string, args []Value) (Rows, error) +} + +// Conn is a connection to a database. It is not used concurrently +// by multiple goroutines. +// +// Conn is assumed to be stateful. +type Conn interface { + // Prepare returns a prepared statement, bound to this connection. + Prepare(query string) (Stmt, error) + + // Close invalidates and potentially stops any current + // prepared statements and transactions, marking this + // connection as no longer in use. + // + // Because the sql package maintains a free pool of + // connections and only calls Close when there's a surplus of + // idle connections, it shouldn't be necessary for drivers to + // do their own connection caching. + Close() error + + // Begin starts and returns a new transaction. + Begin() (Tx, error) +} + +// Result is the result of a query execution. +type Result interface { + // LastInsertId returns the database's auto-generated ID + // after, for example, an INSERT into a table with primary + // key. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by the + // query. + RowsAffected() (int64, error) +} + +// Stmt is a prepared statement. It is bound to a Conn and not +// used by multiple goroutines concurrently. +type Stmt interface { + // Close closes the statement. + // + // As of Go 1.1, a Stmt will not be closed if it's in use + // by any queries. + Close() error + + // NumInput returns the number of placeholder parameters. + // + // If NumInput returns >= 0, the sql package will sanity check + // argument counts from callers and return errors to the caller + // before the statement's Exec or Query methods are called. + // + // NumInput may also return -1, if the driver doesn't know + // its number of placeholders. In that case, the sql package + // will not sanity check Exec or Query argument counts. + NumInput() int + + // Exec executes a query that doesn't return rows, such + // as an INSERT or UPDATE. + Exec(args []Value) (Result, error) + + // Query executes a query that may return rows, such as a + // SELECT. + Query(args []Value) (Rows, error) +} + +// ColumnConverter may be optionally implemented by Stmt if the +// statement is aware of its own columns' types and can convert from +// any type to a driver Value. +type ColumnConverter interface { + // ColumnConverter returns a ValueConverter for the provided + // column index. If the type of a specific column isn't known + // or shouldn't be handled specially, DefaultValueConverter + // can be returned. + ColumnConverter(idx int) ValueConverter +} + +// Rows is an iterator over an executed query's results. +type Rows interface { + // Columns returns the names of the columns. The number of + // columns of the result is inferred from the length of the + // slice. If a particular column name isn't known, an empty + // string should be returned for that entry. + Columns() []string + + // Close closes the rows iterator. + Close() error + + // Next is called to populate the next row of data into + // the provided slice. The provided slice will be the same + // size as the Columns() are wide. + // + // The dest slice may be populated only with + // a driver Value type, but excluding string. + // All string values must be converted to []byte. + // + // Next should return io.EOF when there are no more rows. + Next(dest []Value) error +} + +// Tx is a transaction. +type Tx interface { + Commit() error + Rollback() error +} + +// RowsAffected implements Result for an INSERT or UPDATE operation +// which mutates a number of rows. +type RowsAffected int64 + +var _ Result = RowsAffected(0) + +func (RowsAffected) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available") +} + +func (v RowsAffected) RowsAffected() (int64, error) { + return int64(v), nil +} + +// ResultNoRows is a pre-defined Result for drivers to return when a DDL +// command (such as a CREATE TABLE) succeeds. It returns an error for both +// LastInsertId and RowsAffected. +var ResultNoRows noRows + +type noRows struct{} + +var _ Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available after DDL statement") +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available after DDL statement") +} diff --git a/src/database/sql/driver/types.go b/src/database/sql/driver/types.go new file mode 100644 index 000000000..3305354df --- /dev/null +++ b/src/database/sql/driver/types.go @@ -0,0 +1,252 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "fmt" + "reflect" + "strconv" + "time" +) + +// ValueConverter is the interface providing the ConvertValue method. +// +// Various implementations of ValueConverter are provided by the +// driver package to provide consistent implementations of conversions +// between drivers. The ValueConverters have several uses: +// +// * converting from the Value types as provided by the sql package +// into a database table's specific column type and making sure it +// fits, such as making sure a particular int64 fits in a +// table's uint16 column. +// +// * converting a value as given from the database into one of the +// driver Value types. +// +// * by the sql package, for converting from a driver's Value type +// to a user's type in a scan. +type ValueConverter interface { + // ConvertValue converts a value to a driver Value. + ConvertValue(v interface{}) (Value, error) +} + +// Valuer is the interface providing the Value method. +// +// Types implementing Valuer interface are able to convert +// themselves to a driver Value. +type Valuer interface { + // Value returns a driver Value. + Value() (Value, error) +} + +// Bool is a ValueConverter that converts input values to bools. +// +// The conversion rules are: +// - booleans are returned unchanged +// - for integer types, +// 1 is true +// 0 is false, +// other integers are an error +// - for strings and []byte, same rules as strconv.ParseBool +// - all other types are an error +var Bool boolType + +type boolType struct{} + +var _ ValueConverter = boolType{} + +func (boolType) String() string { return "Bool" } + +func (boolType) ConvertValue(src interface{}) (Value, error) { + switch s := src.(type) { + case bool: + return s, nil + case string: + b, err := strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + case []byte: + b, err := strconv.ParseBool(string(s)) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + } + + sv := reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iv := sv.Int() + if iv == 1 || iv == 0 { + return iv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uv := sv.Uint() + if uv == 1 || uv == 0 { + return uv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv) + } + + return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src) +} + +// Int32 is a ValueConverter that converts input values to int64, +// respecting the limits of an int32 value. +var Int32 int32Type + +type int32Type struct{} + +var _ ValueConverter = int32Type{} + +func (int32Type) ConvertValue(v interface{}) (Value, error) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64 := rv.Int() + if i64 > (1<<31)-1 || i64 < -(1<<31) { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return i64, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64 := rv.Uint() + if u64 > (1<<31)-1 { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return int64(u64), nil + case reflect.String: + i, err := strconv.Atoi(rv.String()) + if err != nil { + return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v) + } + return int64(i), nil + } + return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v) +} + +// String is a ValueConverter that converts its input to a string. +// If the value is already a string or []byte, it's unchanged. +// If the value is of another type, conversion to string is done +// with fmt.Sprintf("%v", v). +var String stringType + +type stringType struct{} + +func (stringType) ConvertValue(v interface{}) (Value, error) { + switch v.(type) { + case string, []byte: + return v, nil + } + return fmt.Sprintf("%v", v), nil +} + +// Null is a type that implements ValueConverter by allowing nil +// values but otherwise delegating to another ValueConverter. +type Null struct { + Converter ValueConverter +} + +func (n Null) ConvertValue(v interface{}) (Value, error) { + if v == nil { + return nil, nil + } + return n.Converter.ConvertValue(v) +} + +// NotNull is a type that implements ValueConverter by disallowing nil +// values but otherwise delegating to another ValueConverter. +type NotNull struct { + Converter ValueConverter +} + +func (n NotNull) ConvertValue(v interface{}) (Value, error) { + if v == nil { + return nil, fmt.Errorf("nil value not allowed") + } + return n.Converter.ConvertValue(v) +} + +// IsValue reports whether v is a valid Value parameter type. +// Unlike IsScanValue, IsValue permits the string type. +func IsValue(v interface{}) bool { + if IsScanValue(v) { + return true + } + if _, ok := v.(string); ok { + return true + } + return false +} + +// IsScanValue reports whether v is a valid Value scan type. +// Unlike IsValue, IsScanValue does not permit the string type. +func IsScanValue(v interface{}) bool { + if v == nil { + return true + } + switch v.(type) { + case int64, float64, []byte, bool, time.Time: + return true + } + return false +} + +// DefaultParameterConverter is the default implementation of +// ValueConverter that's used when a Stmt doesn't implement +// ColumnConverter. +// +// DefaultParameterConverter returns the given value directly if +// IsValue(value). Otherwise integer type are converted to +// int64, floats to float64, and strings to []byte. Other types are +// an error. +var DefaultParameterConverter defaultConverter + +type defaultConverter struct{} + +var _ ValueConverter = defaultConverter{} + +func (defaultConverter) ConvertValue(v interface{}) (Value, error) { + if IsValue(v) { + return v, nil + } + + if svi, ok := v.(Valuer); ok { + sv, err := svi.Value() + if err != nil { + return nil, err + } + if !IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return defaultConverter{}.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return nil, fmt.Errorf("uint64 values with high bit set are not supported") + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} diff --git a/src/database/sql/driver/types_test.go b/src/database/sql/driver/types_test.go new file mode 100644 index 000000000..1ce0ff065 --- /dev/null +++ b/src/database/sql/driver/types_test.go @@ -0,0 +1,65 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "reflect" + "testing" + "time" +) + +type valueConverterTest struct { + c ValueConverter + in interface{} + out interface{} + err string +} + +var now = time.Now() +var answer int64 = 42 + +var valueConverterTests = []valueConverterTest{ + {Bool, "true", true, ""}, + {Bool, "True", true, ""}, + {Bool, []byte("t"), true, ""}, + {Bool, true, true, ""}, + {Bool, "1", true, ""}, + {Bool, 1, true, ""}, + {Bool, int64(1), true, ""}, + {Bool, uint16(1), true, ""}, + {Bool, "false", false, ""}, + {Bool, false, false, ""}, + {Bool, "0", false, ""}, + {Bool, 0, false, ""}, + {Bool, int64(0), false, ""}, + {Bool, uint16(0), false, ""}, + {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, + {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, + {DefaultParameterConverter, now, now, ""}, + {DefaultParameterConverter, (*int64)(nil), nil, ""}, + {DefaultParameterConverter, &answer, answer, ""}, + {DefaultParameterConverter, &now, now, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} diff --git a/src/database/sql/example_test.go b/src/database/sql/example_test.go new file mode 100644 index 000000000..dcb74e069 --- /dev/null +++ b/src/database/sql/example_test.go @@ -0,0 +1,46 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql_test + +import ( + "database/sql" + "fmt" + "log" +) + +var db *sql.DB + +func ExampleDB_Query() { + age := 27 + rows, err := db.Query("SELECT name FROM users WHERE age=?", age) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + log.Fatal(err) + } + fmt.Printf("%s is %d\n", name, age) + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} + +func ExampleDB_QueryRow() { + id := 123 + var username string + err := db.QueryRow("SELECT username FROM users WHERE id=?", id).Scan(&username) + switch { + case err == sql.ErrNoRows: + log.Printf("No user with that ID.") + case err != nil: + log.Fatal(err) + default: + fmt.Printf("Username is %s\n", username) + } +} diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go new file mode 100644 index 000000000..a993fd46e --- /dev/null +++ b/src/database/sql/fakedb_test.go @@ -0,0 +1,829 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "log" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +var _ = log.Printf + +// fakeDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntactically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE|<tablename>|<col>=<type>,<col>=<type>,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT|<tablename>|col=val,col2=val2,col3=? +// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? +// +// When opening a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type fakeDriver struct { + mu sync.Mutex // guards 3 following fields + openCount int // conn opens + closeCount int // conn closes + waitCh chan struct{} + waitingCh chan struct{} + dbs map[string]*fakeDB +} + +type fakeDB struct { + name string + + mu sync.Mutex + free []*fakeConn + tables map[string]*table + badConn bool +} + +type table struct { + mu sync.Mutex + colname []string + coltype []string + rows []*row +} + +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type row struct { + cols []interface{} // must be same size as its table colname + coltype +} + +func (r *row) clone() *row { + nrow := &row{cols: make([]interface{}, len(r.cols))} + copy(nrow.cols, r.cols) + return nrow +} + +type fakeConn struct { + db *fakeDB // where to return ourselves to + + currTx *fakeTx + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int + numPrepare int + bad bool +} + +func (c *fakeConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() +} + +type fakeTx struct { + c *fakeConn +} + +type fakeStmt struct { + c *fakeConn + q string // just for debugging + + cmd string + table string + + closed bool + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []string // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var fdriver driver.Driver = &fakeDriver{} + +func init() { + Register("test", fdriver) +} + +func contains(list []string, y string) bool { + for _, x := range list { + if x == y { + return true + } + } + return false +} + +type Dummy struct { + driver.Driver +} + +func TestDrivers(t *testing.T) { + unregisterAllDrivers() + Register("test", fdriver) + Register("invalid", Dummy{}) + all := Drivers() + if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { + t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) + } +} + +// Supports dsn forms: +// <dbname> +// <dbname>;<opts> (only currently supported option is `badConn`, +// which causes driver.ErrBadConn to be returned on +// every other conn.Begin()) +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, errors.New("fakedb: no database name") + } + name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + conn := &fakeConn{db: db} + + if len(parts) >= 2 && parts[1] == "badConn" { + conn.bad = true + } + if d.waitCh != nil { + d.waitingCh <- struct{}{} + <-d.waitCh + d.waitCh = nil + d.waitingCh = nil + } + return conn, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + db, ok := d.dbs[name] + if !ok { + db = &fakeDB{name: name} + d.dbs[name] = db + } + return db +} + +func (db *fakeDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.tables = nil +} + +func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { + db.mu.Lock() + defer db.mu.Unlock() + if db.tables == nil { + db.tables = make(map[string]*table) + } + if _, exist := db.tables[name]; exist { + return fmt.Errorf("table %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", + name, len(columnNames), len(columnTypes)) + } + db.tables[name] = &table{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *fakeDB) table(table string) (*table, bool) { + if db.tables == nil { + return nil, false + } + t, ok := db.tables[table] + return t, ok +} + +func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.table(table) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *fakeConn) isBad() bool { + // if not simulating bad conn, do nothing + if !c.bad { + return false + } + // alternate between bad conn and not bad conn + c.db.badConn = !c.db.badConn + return c.db.badConn +} + +func (c *fakeConn) Begin() (driver.Tx, error) { + if c.isBad() { + return nil, driver.ErrBadConn + } + if c.currTx != nil { + return nil, errors.New("already in a transaction") + } + c.currTx = &fakeTx{c: c} + return c.currTx, nil +} + +var hookPostCloseConn struct { + sync.Mutex + fn func(*fakeConn, error) +} + +func setHookpostCloseConn(fn func(*fakeConn, error)) { + hookPostCloseConn.Lock() + defer hookPostCloseConn.Unlock() + hookPostCloseConn.fn = fn +} + +var testStrictClose *testing.T + +// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close +// fails to close. If nil, the check is disabled. +func setStrictFakeConnClose(t *testing.T) { + testStrictClose = t +} + +func (c *fakeConn) Close() (err error) { + drv := fdriver.(*fakeDriver) + defer func() { + if err != nil && testStrictClose != nil { + testStrictClose.Errorf("failed to close a test fakeConn: %v", err) + } + hookPostCloseConn.Lock() + fn := hookPostCloseConn.fn + hookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + if err == nil { + drv.mu.Lock() + drv.closeCount++ + drv.mu.Unlock() + } + }() + if c.currTx != nil { + return errors.New("can't close fakeConn; in a Transaction") + } + if c.db == nil { + return errors.New("can't close fakeConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("can't close; dangling statement(s)") + } + c.db = nil + return nil +} + +func checkSubsetTypes(args []driver.Value) error { + for n, arg := range args { + switch arg.(type) { + case int64, float64, bool, nil, []byte, string, time.Time: + default: + return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + } + } + return nil +} + +func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func errf(msg string, args ...interface{}) error { + return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) +} + +// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where columns must always contain ? marks, +// just a limitation for fakedb) +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 3 { + stmt.Close() + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.table = parts[0] + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) + } + if value != "?" { + stmt.Close() + return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", + stmt.table, column) + } + stmt.whereCol = append(stmt.whereCol, column) + stmt.placeholders++ + } + return stmt, nil +} + +// parts are table|col=type,col2=type2 +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + stmt.Close() + return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are table|col=?,col2=val +func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) + } + stmt.colName = append(stmt.colName, column) + + if value != "?" { + var subsetVal interface{} + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + stmt.Close() + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + default: + stmt.Close() + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, "?") + } + } + return stmt, nil +} + +// hook to simulate broken connections +var hookPrepareBadConn func() bool + +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + c.numPrepare++ + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + + if hookPrepareBadConn != nil && hookPrepareBadConn() { + return nil, driver.ErrBadConn + } + + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + cmd := parts[0] + parts = parts[1:] + stmt := &fakeStmt{q: query, c: c, cmd: cmd} + c.incrStat(&c.stmtsMade) + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + return c.prepareSelect(stmt, parts) + case "CREATE": + return c.prepareCreate(stmt, parts) + case "INSERT": + return c.prepareInsert(stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return c.prepareInsert(stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + return stmt, nil +} + +func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + if len(s.placeholderConverter) == 0 { + return driver.DefaultParameterConverter + } + return s.placeholderConverter[idx] +} + +func (s *fakeStmt) Close() error { + if s.c == nil { + panic("nil conn in fakeStmt.Close") + } + if s.c.db == nil { + panic("in fakeStmt.Close, conn's db is nil (already closed)") + } + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } + return nil +} + +var errClosed = errors.New("fakedb: statement has been closed") + +// hook to simulate broken connections +var hookExecBadConn func() bool + +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + if s.closed { + return nil, errClosed + } + + if hookExecBadConn != nil && hookExecBadConn() { + return nil, driver.ErrBadConn + } + + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.ResultNoRows, nil + case "CREATE": + if err := db.createTable(s.table, s.colName, s.colType); err != nil { + return nil, err + } + return driver.ResultNoRows, nil + case "INSERT": + return s.execInsert(args, true) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return s.execInsert(args, false) + } + fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) + return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) +} + +// When doInsert is true, add the row to the table. +// When doInsert is false do prep-work and error checking, but don't +// actually add the row to the table. +func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + t.mu.Lock() + defer t.mu.Unlock() + + var cols []interface{} + if doInsert { + cols = make([]interface{}, len(t.colname)) + } + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val interface{} + if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { + val = args[argPos] + argPos++ + } else { + val = s.colValue[n] + } + if doInsert { + cols[colidx] = val + } + } + + if doInsert { + t.rows = append(t.rows, &row{cols: cols}) + } + return driver.RowsAffected(1), nil +} + +// hook to simulate broken connections +var hookQueryBadConn func() bool + +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + if s.closed { + return nil, errClosed + } + + if hookQueryBadConn != nil && hookQueryBadConn() { + return nil, driver.ErrBadConn + } + + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + if s.table == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { + if args[0] == "sleep" { + time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + } + } + } + + t.mu.Lock() + defer t.mu.Unlock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mrows := []*row{} +rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for widx, wcol := range s.whereCol { + idx := t.columnIndex(wcol) + if idx == -1 { + return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + continue rows + } + } + mrow := &row{cols: make([]interface{}, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] + } + mrows = append(mrows, mrow) + } + + cursor := &rowsCursor{ + pos: -1, + rows: mrows, + cols: s.colName, + errPos: -1, + } + return cursor, nil +} + +func (s *fakeStmt) NumInput() int { + return s.placeholders +} + +func (tx *fakeTx) Commit() error { + tx.c.currTx = nil + return nil +} + +func (tx *fakeTx) Rollback() error { + tx.c.currTx = nil + return nil +} + +type rowsCursor struct { + cols []string + pos int + rows []*row + closed bool + + // errPos and err are for making Next return early with error. + errPos int + err error + + // a clone of slices to give out to clients, indexed by the + // the original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte +} + +func (rc *rowsCursor) Close() error { + if !rc.closed { + for _, bs := range rc.bytesClone { + bs[0] = 255 // first byte corrupted + } + } + rc.closed = true + return nil +} + +func (rc *rowsCursor) Columns() []string { + return rc.cols +} + +var rowsCursorNextHook func(dest []driver.Value) error + +func (rc *rowsCursor) Next(dest []driver.Value) error { + if rowsCursorNextHook != nil { + return rowsCursorNextHook(dest) + } + + if rc.closed { + return errors.New("fakedb: cursor is closed") + } + rc.pos++ + if rc.pos == rc.errPos { + return rc.err + } + if rc.pos >= len(rc.rows) { + return io.EOF // per interface spec + } + for i, v := range rc.rows[rc.pos].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the sql package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } + } + return nil +} + +// fakeDriverString is like driver.String, but indirects pointers like +// DefaultValueConverter. +// +// This could be surprising behavior to retroactively apply to +// driver.String now that Go1 is out, but this is convenient for +// our TestPointerParamsAndScans. +// +type fakeDriverString struct{} + +func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { + switch c := v.(type) { + case string, []byte: + return v, nil + case *string: + if c == nil { + return nil, nil + } + return *c, nil + } + return fmt.Sprintf("%v", v), nil +} + +func converterForType(typ string) driver.ValueConverter { + switch typ { + case "bool": + return driver.Bool + case "nullbool": + return driver.Null{Converter: driver.Bool} + case "int32": + return driver.Int32 + case "string": + return driver.NotNull{Converter: fakeDriverString{}} + case "nullstring": + return driver.Null{Converter: fakeDriverString{}} + case "int64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullint64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "float64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullfloat64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "datetime": + return driver.DefaultParameterConverter + } + panic("invalid fakedb column type of " + typ) +} diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go new file mode 100644 index 000000000..6e6f246ae --- /dev/null +++ b/src/database/sql/sql.go @@ -0,0 +1,1770 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +// +// The sql package must be used in conjunction with a database driver. +// See http://golang.org/s/sqldrivers for a list of drivers. +// +// For more usage examples, see the wiki page at +// http://golang.org/s/sqlwiki. +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "runtime" + "sort" + "sync" +) + +var drivers = make(map[string]driver.Driver) + +// Register makes a database driver available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, driver driver.Driver) { + if driver == nil { + panic("sql: Register driver is nil") + } + if _, dup := drivers[name]; dup { + panic("sql: Register called twice for driver " + name) + } + drivers[name] = driver +} + +func unregisterAllDrivers() { + // For tests. + drivers = make(map[string]driver.Driver) +} + +// Drivers returns a sorted list of the names of the registered drivers. +func Drivers() []string { + var list []string + for name := range drivers { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// RawBytes is a byte slice that holds a reference to memory owned by +// the database itself. After a Scan into a RawBytes, the slice is only +// valid until the next call to Next, Scan, or Close. +type RawBytes []byte + +// NullString represents a string that may be null. +// NullString implements the Scanner interface so +// it can be used as a scan destination: +// +// var s NullString +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s) +// ... +// if s.Valid { +// // use s.String +// } else { +// // NULL value +// } +// +type NullString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullString) Scan(value interface{}) error { + if value == nil { + ns.String, ns.Valid = "", false + return nil + } + ns.Valid = true + return convertAssign(&ns.String, value) +} + +// Value implements the driver Valuer interface. +func (ns NullString) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.String, nil +} + +// NullInt64 represents an int64 that may be null. +// NullInt64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt64) Scan(value interface{}) error { + if value == nil { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int64, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int64, nil +} + +// NullFloat64 represents a float64 that may be null. +// NullFloat64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat64) Scan(value interface{}) error { + if value == nil { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Float64, value) +} + +// Value implements the driver Valuer interface. +func (n NullFloat64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Float64, nil +} + +// NullBool represents a bool that may be null. +// NullBool implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullBool) Scan(value interface{}) error { + if value == nil { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + return convertAssign(&n.Bool, value) +} + +// Value implements the driver Valuer interface. +func (n NullBool) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Bool, nil +} + +// Scanner is an interface used by Scan. +type Scanner interface { + // Scan assigns a value from a database driver. + // + // The src value will be of one of the following restricted + // set of types: + // + // int64 + // float64 + // bool + // []byte + // string + // time.Time + // nil - for NULL values + // + // An error should be returned if the value can not be stored + // without loss of information. + Scan(src interface{}) error +} + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. In such a case, QueryRow returns a placeholder *Row value that +// defers this error until a Scan. +var ErrNoRows = errors.New("sql: no rows in result set") + +// DB is a database handle representing a pool of zero or more +// underlying connections. It's safe for concurrent use by multiple +// goroutines. +// +// The sql package creates and frees connections automatically; it +// also maintains a free pool of idle connections. If the database has +// a concept of per-connection state, such state can only be reliably +// observed within a transaction. Once DB.Begin is called, the +// returned Tx is bound to a single connection. Once Commit or +// Rollback is called on the transaction, that transaction's +// connection is returned to DB's idle connection pool. The pool size +// can be controlled with SetMaxIdleConns. +type DB struct { + driver driver.Driver + dsn string + + mu sync.Mutex // protects following fields + freeConn []*driverConn + connRequests []chan connRequest + numOpen int + pendingOpens int + // Used to signal the need for new connections + // a goroutine running connectionOpener() reads on this chan and + // maybeOpenNewConnections sends on the chan (one send per needed connection) + // It is closed during db.Close(). The close tells the connectionOpener + // goroutine to exit. + openerCh chan struct{} + closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdle int // zero means defaultMaxIdleConns; negative means 0 + maxOpen int // <= 0 means unlimited +} + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + db *DB + + sync.Mutex // guards following + ci driver.Conn + closed bool + finalClosed bool // ci.Close has been called + openStmt map[driver.Stmt]bool + + // guarded by db.mu + inUse bool + onPut []func() // code (with db.mu held) run when conn is next returned + dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree +} + +func (dc *driverConn) releaseConn(err error) { + dc.db.putConn(dc, err) +} + +func (dc *driverConn) removeOpenStmt(si driver.Stmt) { + dc.Lock() + defer dc.Unlock() + delete(dc.openStmt, si) +} + +func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { + si, err := dc.ci.Prepare(query) + if err == nil { + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // TODO(bradfitz): let drivers opt out of caring about + // stmt closes if the conn is about to close anyway? For now + // do the safe thing, in case stmts need to be closed. + // + // TODO(bradfitz): after Go 1.2, closing driver.Stmts + // should be moved to driverStmt, using unique + // *driverStmts everywhere (including from + // *Stmt.connStmt, instead of returning a + // driver.Stmt), using driverStmt as a pointer + // everywhere, and making it a finalCloser. + if dc.openStmt == nil { + dc.openStmt = make(map[driver.Stmt]bool) + } + dc.openStmt[si] = true + } + return si, err +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() func() error { + dc.Lock() + defer dc.Unlock() + if dc.closed { + return func() error { return errors.New("sql: duplicate driverConn close") } + } + dc.closed = true + return dc.db.removeDepLocked(dc, dc) +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + + // And now updates that require holding dc.mu.Lock. + dc.db.mu.Lock() + dc.dbmuClosed = true + fn := dc.db.removeDepLocked(dc, dc) + dc.db.mu.Unlock() + return fn() +} + +func (dc *driverConn) finalClose() error { + dc.Lock() + + for si := range dc.openStmt { + si.Close() + } + dc.openStmt = nil + + err := dc.ci.Close() + dc.ci = nil + dc.finalClosed = true + dc.Unlock() + + dc.db.mu.Lock() + dc.db.numOpen-- + dc.db.maybeOpenNewConnections() + dc.db.mu.Unlock() + + return err +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt +} + +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + return ds.si.Close() +} + +// depSet is a finalCloser's outstanding dependencies +type depSet map[interface{}]bool // set of true bools + +// The finalCloser interface is used by (*DB).addDep and related +// dependency reference counting. +type finalCloser interface { + // finalClose is called when the reference count of an object + // goes to zero. (*DB).mu is not held while calling it. + finalClose() error +} + +// addDep notes that x now depends on dep, and x's finalClose won't be +// called until all of x's dependencies are removed with removeDep. +func (db *DB) addDep(x finalCloser, dep interface{}) { + //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep)) + db.mu.Lock() + defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep interface{}) { + if db.dep == nil { + db.dep = make(map[finalCloser]depSet) + } + xdep := db.dep[x] + if xdep == nil { + xdep = make(depSet) + db.dep[x] = xdep + } + xdep[dep] = true +} + +// removeDep notes that x no longer depends on dep. +// If x still has dependencies, nil is returned. +// If x no longer has any dependencies, its finalClose method will be +// called and its error value will be returned. +func (db *DB) removeDep(x finalCloser, dep interface{}) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { + //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep)) + + xdep, ok := db.dep[x] + if !ok { + panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x)) + } + + l0 := len(xdep) + delete(xdep, dep) + + switch len(xdep) { + case l0: + // Nothing removed. Shouldn't happen. + panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x)) + case 0: + // No more dependencies. + delete(db.dep, x) + return x.finalClose + default: + // Dependencies remain. + return func() error { return nil } + } +} + +// This is the size of the connectionOpener request chan (dn.openerCh). +// This value should be larger than the maximum typical value +// used for db.maxOpen. If maxOpen is significantly larger than +// connectionRequestQueueSize then it is possible for ALL calls into the *DB +// to block until the connectionOpener can satisfy the backlog of requests. +var connectionRequestQueueSize = 1000000 + +// Open opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See http://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// Open may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the Open +// function should be called just once. It is rarely necessary to +// close a DB. +func Open(driverName, dataSourceName string) (*DB, error) { + driveri, ok := drivers[driverName] + if !ok { + return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) + } + db := &DB{ + driver: driveri, + dsn: dataSourceName, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + } + go db.connectionOpener() + return db, nil +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + // TODO(bradfitz): give drivers an optional hook to implement + // this in a more efficient or more reliable way, if they + // have one. + dc, err := db.conn() + if err != nil { + return err + } + db.putConn(dc, nil) + return nil +} + +// Close closes the database, releasing any open resources. +// +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. +func (db *DB) Close() error { + db.mu.Lock() + if db.closed { // Make DB.Close idempotent + db.mu.Unlock() + return nil + } + close(db.openerCh) + var err error + fns := make([]func() error, 0, len(db.freeConn)) + for _, dc := range db.freeConn { + fns = append(fns, dc.closeDBLocked()) + } + db.freeConn = nil + db.closed = true + for _, req := range db.connRequests { + close(req) + } + db.mu.Unlock() + for _, fn := range fns { + err1 := fn() + if err1 != nil { + err = err1 + } + } + return err +} + +const defaultMaxIdleConns = 2 + +func (db *DB) maxIdleConnsLocked() int { + n := db.maxIdle + switch { + case n == 0: + // TODO(bradfitz): ask driver, if supported, for its default preference + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns +// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit +// +// If n <= 0, no idle connections are retained. +func (db *DB) SetMaxIdleConns(n int) { + db.mu.Lock() + if n > 0 { + db.maxIdle = n + } else { + // No idle connections. + db.maxIdle = -1 + } + // Make sure maxIdle doesn't exceed maxOpen + if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen { + db.maxIdle = db.maxOpen + } + var closing []*driverConn + idleCount := len(db.freeConn) + maxIdle := db.maxIdleConnsLocked() + if idleCount > maxIdle { + closing = db.freeConn[maxIdle:] + db.freeConn = db.freeConn[:maxIdle] + } + db.mu.Unlock() + for _, c := range closing { + c.Close() + } +} + +// SetMaxOpenConns sets the maximum number of open connections to the database. +// +// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than +// MaxIdleConns, then MaxIdleConns will be reduced to match the new +// MaxOpenConns limit +// +// If n <= 0, then there is no limit on the number of open connections. +// The default is 0 (unlimited). +func (db *DB) SetMaxOpenConns(n int) { + db.mu.Lock() + db.maxOpen = n + if n < 0 { + db.maxOpen = 0 + } + syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen + db.mu.Unlock() + if syncMaxIdle { + db.SetMaxIdleConns(n) + } +} + +// Assumes db.mu is locked. +// If there are connRequests and the connection limit hasn't been reached, +// then tell the connectionOpener to open new connections. +func (db *DB) maybeOpenNewConnections() { + numRequests := len(db.connRequests) - db.pendingOpens + if db.maxOpen > 0 { + numCanOpen := db.maxOpen - (db.numOpen + db.pendingOpens) + if numRequests > numCanOpen { + numRequests = numCanOpen + } + } + for numRequests > 0 { + db.pendingOpens++ + numRequests-- + db.openerCh <- struct{}{} + } +} + +// Runs in a separate goroutine, opens new connections when requested. +func (db *DB) connectionOpener() { + for range db.openerCh { + db.openNewConnection() + } +} + +// Open one new connection +func (db *DB) openNewConnection() { + ci, err := db.driver.Open(db.dsn) + db.mu.Lock() + defer db.mu.Unlock() + if db.closed { + if err == nil { + ci.Close() + } + return + } + db.pendingOpens-- + if err != nil { + db.putConnDBLocked(nil, err) + return + } + dc := &driverConn{ + db: db, + ci: ci, + } + if db.putConnDBLocked(dc, err) { + db.addDepLocked(dc, dc) + db.numOpen++ + } else { + ci.Close() + } +} + +// connRequest represents one request for a new connection +// When there are no idle connections available, DB.conn will create +// a new connRequest and put it on the db.connRequests list. +type connRequest struct { + conn *driverConn + err error +} + +var errDBClosed = errors.New("sql: database is closed") + +// conn returns a newly-opened or cached *driverConn +func (db *DB) conn() (*driverConn, error) { + db.mu.Lock() + if db.closed { + db.mu.Unlock() + return nil, errDBClosed + } + + // If db.maxOpen > 0 and the number of open connections is over the limit + // and there are no free connection, make a request and wait. + if db.maxOpen > 0 && db.numOpen >= db.maxOpen && len(db.freeConn) == 0 { + // Make the connRequest channel. It's buffered so that the + // connectionOpener doesn't block while waiting for the req to be read. + req := make(chan connRequest, 1) + db.connRequests = append(db.connRequests, req) + db.maybeOpenNewConnections() + db.mu.Unlock() + ret := <-req + return ret.conn, ret.err + } + + if c := len(db.freeConn); c > 0 { + conn := db.freeConn[0] + copy(db.freeConn, db.freeConn[1:]) + db.freeConn = db.freeConn[:c-1] + conn.inUse = true + db.mu.Unlock() + return conn, nil + } + + db.numOpen++ // optimistically + db.mu.Unlock() + ci, err := db.driver.Open(db.dsn) + if err != nil { + db.mu.Lock() + db.numOpen-- // correct for earlier optimism + db.mu.Unlock() + return nil, err + } + db.mu.Lock() + dc := &driverConn{ + db: db, + ci: ci, + } + db.addDepLocked(dc, dc) + dc.inUse = true + db.mu.Unlock() + return dc, nil +} + +var ( + errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed") + errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy") +) + +// connIfFree returns (wanted, nil) if wanted is still a valid conn and +// isn't in use. +// +// The error is errConnClosed if the connection if the requested connection +// is invalid because it's been closed. +// +// The error is errConnBusy if the connection is in use. +func (db *DB) connIfFree(wanted *driverConn) (*driverConn, error) { + db.mu.Lock() + defer db.mu.Unlock() + if wanted.dbmuClosed { + return nil, errConnClosed + } + if wanted.inUse { + return nil, errConnBusy + } + idx := -1 + for ii, v := range db.freeConn { + if v == wanted { + idx = ii + break + } + } + if idx >= 0 { + db.freeConn = append(db.freeConn[:idx], db.freeConn[idx+1:]...) + wanted.inUse = true + return wanted, nil + } + // TODO(bradfitz): shouldn't get here. After Go 1.1, change this to: + // panic("connIfFree call requested a non-closed, non-busy, non-free conn") + // Which passes all the tests, but I'm too paranoid to include this + // late in Go 1.1. + // Instead, treat it like a busy connection: + return nil, errConnBusy +} + +// putConnHook is a hook for testing. +var putConnHook func(*DB, *driverConn) + +// noteUnusedDriverStatement notes that si is no longer used and should +// be closed whenever possible (when c is next not in use), unless c is +// already closed. +func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { + db.mu.Lock() + defer db.mu.Unlock() + if c.inUse { + c.onPut = append(c.onPut, func() { + si.Close() + }) + } else { + c.Lock() + defer c.Unlock() + if !c.finalClosed { + si.Close() + } + } +} + +// debugGetPut determines whether getConn & putConn calls' stack traces +// are returned for more verbose crashes. +const debugGetPut = false + +// putConn adds a connection to the db's free pool. +// err is optionally the last error that occurred on this connection. +func (db *DB) putConn(dc *driverConn, err error) { + db.mu.Lock() + if !dc.inUse { + if debugGetPut { + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) + } + panic("sql: connection returned that was never out") + } + if debugGetPut { + db.lastPut[dc] = stack() + } + dc.inUse = false + + for _, fn := range dc.onPut { + fn() + } + dc.onPut = nil + + if err == driver.ErrBadConn { + // Don't reuse bad connections. + // Since the conn is considered bad and is being discarded, treat it + // as closed. Don't decrement the open count here, finalClose will + // take care of that. + db.maybeOpenNewConnections() + db.mu.Unlock() + dc.Close() + return + } + if putConnHook != nil { + putConnHook(db, dc) + } + added := db.putConnDBLocked(dc, nil) + db.mu.Unlock() + + if !added { + dc.Close() + } +} + +// Satisfy a connRequest or put the driverConn in the idle pool and return true +// or return false. +// putConnDBLocked will satisfy a connRequest if there is one, or it will +// return the *driverConn to the freeConn list if err == nil and the idle +// connection limit will not be exceeded. +// If err != nil, the value of dc is ignored. +// If err == nil, then dc must not equal nil. +// If a connRequest was fulfilled or the *driverConn was placed in the +// freeConn list, then true is returned, otherwise false is returned. +func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { + if c := len(db.connRequests); c > 0 { + req := db.connRequests[0] + // This copy is O(n) but in practice faster than a linked list. + // TODO: consider compacting it down less often and + // moving the base instead? + copy(db.connRequests, db.connRequests[1:]) + db.connRequests = db.connRequests[:c-1] + if err == nil { + dc.inUse = true + } + req <- connRequest{ + conn: dc, + err: err, + } + return true + } else if err == nil && !db.closed && db.maxIdleConnsLocked() > len(db.freeConn) { + db.freeConn = append(db.freeConn, dc) + return true + } + return false +} + +// maxBadConnRetries is the number of maximum retries if the driver returns +// driver.ErrBadConn to signal a broken connection. +const maxBadConnRetries = 10 + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +func (db *DB) Prepare(query string) (*Stmt, error) { + var stmt *Stmt + var err error + for i := 0; i < maxBadConnRetries; i++ { + stmt, err = db.prepare(query) + if err != driver.ErrBadConn { + break + } + } + return stmt, err +} + +func (db *DB) prepare(query string) (*Stmt, error) { + // TODO: check if db.driver supports an optional + // driver.Preparer interface and call that instead, if so, + // otherwise we make a prepared statement that's bound + // to a connection, and to execute this prepared statement + // we either need to use this connection (if it's free), else + // get a new connection + re-prepare + execute on that one. + dc, err := db.conn() + if err != nil { + return nil, err + } + dc.Lock() + si, err := dc.prepareLocked(query) + dc.Unlock() + if err != nil { + db.putConn(dc, err) + return nil, err + } + stmt := &Stmt{ + db: db, + query: query, + css: []connStmt{{dc, si}}, + } + db.addDep(stmt, stmt) + db.putConn(dc, nil) + return stmt, nil +} + +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) Exec(query string, args ...interface{}) (Result, error) { + var res Result + var err error + for i := 0; i < maxBadConnRetries; i++ { + res, err = db.exec(query, args) + if err != driver.ErrBadConn { + break + } + } + return res, err +} + +func (db *DB) exec(query string, args []interface{}) (res Result, err error) { + dc, err := db.conn() + if err != nil { + return nil, err + } + defer func() { + db.putConn(dc, err) + }() + + if execer, ok := dc.ci.(driver.Execer); ok { + dargs, err := driverArgs(nil, args) + if err != nil { + return nil, err + } + dc.Lock() + resi, err := execer.Exec(query, dargs) + dc.Unlock() + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return driverResult{dc, resi}, nil + } + } + + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() + if err != nil { + return nil, err + } + defer withLock(dc, func() { si.Close() }) + return resultFromStatement(driverStmt{dc, si}, args...) +} + +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + var rows *Rows + var err error + for i := 0; i < maxBadConnRetries; i++ { + rows, err = db.query(query, args) + if err != driver.ErrBadConn { + break + } + } + return rows, err +} + +func (db *DB) query(query string, args []interface{}) (*Rows, error) { + ci, err := db.conn() + if err != nil { + return nil, err + } + + return db.queryConn(ci, ci.releaseConn, query, args) +} + +// queryConn executes a query on the given connection. +// The connection gets released by the releaseConn function. +func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { + if queryer, ok := dc.ci.(driver.Queryer); ok { + dargs, err := driverArgs(nil, args) + if err != nil { + releaseConn(err) + return nil, err + } + dc.Lock() + rowsi, err := queryer.Query(query, dargs) + dc.Unlock() + if err != driver.ErrSkip { + if err != nil { + releaseConn(err) + return nil, err + } + // Note: ownership of dc passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + } + return rows, nil + } + } + + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() + if err != nil { + releaseConn(err) + return nil, err + } + + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) + if err != nil { + dc.Lock() + si.Close() + dc.Unlock() + releaseConn(err) + return nil, err + } + + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + closeStmt: si, + } + return rows, nil +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + rows, err := db.Query(query, args...) + return &Row{rows: rows, err: err} +} + +// Begin starts a transaction. The isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, error) { + var tx *Tx + var err error + for i := 0; i < maxBadConnRetries; i++ { + tx, err = db.begin() + if err != driver.ErrBadConn { + break + } + } + return tx, err +} + +func (db *DB) begin() (tx *Tx, err error) { + dc, err := db.conn() + if err != nil { + return nil, err + } + dc.Lock() + txi, err := dc.ci.Begin() + dc.Unlock() + if err != nil { + db.putConn(dc, err) + return nil, err + } + return &Tx{ + db: db, + dc: dc, + txi: txi, + }, nil +} + +// Driver returns the database's underlying driver. +func (db *DB) Driver() driver.Driver { + return db.driver +} + +// Tx is an in-progress database transaction. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the +// transaction fail with ErrTxDone. +type Tx struct { + db *DB + + // dc is owned exclusively until Commit or Rollback, at which point + // it's returned with putConn. + dc *driverConn + txi driver.Tx + + // done transitions from false to true exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTxDone. + done bool + + // All Stmts prepared for this transaction. These will be closed after the + // transaction has been committed or rolled back. + stmts struct { + sync.Mutex + v []*Stmt + } +} + +var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") + +func (tx *Tx) close() { + if tx.done { + panic("double close") // internal error + } + tx.done = true + tx.db.putConn(tx.dc, nil) + tx.dc = nil + tx.txi = nil +} + +func (tx *Tx) grabConn() (*driverConn, error) { + if tx.done { + return nil, ErrTxDone + } + return tx.dc, nil +} + +// Closes all Stmts prepared for this transaction. +func (tx *Tx) closePrepared() { + tx.stmts.Lock() + for _, stmt := range tx.stmts.v { + stmt.Close() + } + tx.stmts.Unlock() +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + if tx.done { + return ErrTxDone + } + defer tx.close() + tx.dc.Lock() + err := tx.txi.Commit() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err +} + +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + if tx.done { + return ErrTxDone + } + defer tx.close() + tx.dc.Lock() + err := tx.txi.Rollback() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err +} + +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + // TODO(bradfitz): We could be more efficient here and either + // provide a method to take an existing Stmt (created on + // perhaps a different Conn), and re-create it on this Conn if + // necessary. Or, better: keep a map in DB of query string to + // Stmts, and have Stmt.Execute do the right thing and + // re-prepare if the Conn in use doesn't have that prepared + // statement. But we'll want to avoid caching the statement + // in the case where we only call conn.Prepare implicitly + // (such as in db.Exec or tx.Exec), but the caller package + // can't be holding a reference to the returned statement. + // Perhaps just looking at the reference count (by noting + // Stmt.Close) would be enough. We might also want a finalizer + // on Stmt to drop the reference count. + dc, err := tx.grabConn() + if err != nil { + return nil, err + } + + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() + if err != nil { + return nil, err + } + + stmt := &Stmt{ + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, + query: query, + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, stmt) + tx.stmts.Unlock() + return stmt, nil +} + +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + // TODO(bradfitz): optimize this. Currently this re-prepares + // each time. This is fine for now to illustrate the API but + // we should really cache already-prepared statements + // per-Conn. See also the big comment in Tx.Prepare. + + if tx.db != stmt.db { + return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} + } + dc, err := tx.grabConn() + if err != nil { + return &Stmt{stickyErr: err} + } + dc.Lock() + si, err := dc.ci.Prepare(stmt.query) + dc.Unlock() + txs := &Stmt{ + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, + query: stmt.query, + stickyErr: err, + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, txs) + tx.stmts.Unlock() + return txs +} + +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + dc, err := tx.grabConn() + if err != nil { + return nil, err + } + + if execer, ok := dc.ci.(driver.Execer); ok { + dargs, err := driverArgs(nil, args) + if err != nil { + return nil, err + } + dc.Lock() + resi, err := execer.Exec(query, dargs) + dc.Unlock() + if err == nil { + return driverResult{dc, resi}, nil + } + if err != driver.ErrSkip { + return nil, err + } + } + + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() + if err != nil { + return nil, err + } + defer withLock(dc, func() { si.Close() }) + + return resultFromStatement(driverStmt{dc, si}, args...) +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + dc, err := tx.grabConn() + if err != nil { + return nil, err + } + releaseConn := func(error) {} + return tx.db.queryConn(dc, releaseConn, query, args) +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + rows, err := tx.Query(query, args...) + return &Row{rows: rows, err: err} +} + +// connStmt is a prepared statement on a particular connection. +type connStmt struct { + dc *driverConn + si driver.Stmt +} + +// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines. +type Stmt struct { + // Immutable: + db *DB // where we came from + query string // that created the Stmt + stickyErr error // if non-nil, this error is returned for all operations + + closemu sync.RWMutex // held exclusively during close, for read otherwise. + + // If in a transaction, else both nil: + tx *Tx + txsi *driverStmt + + mu sync.Mutex // protects the rest of the fields + closed bool + + // css is a list of underlying driver statement interfaces + // that are valid on particular connections. This is only + // used if tx == nil and one is found that has idle + // connections. If tx != nil, txsi is always used. + css []connStmt +} + +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var res Result + for i := 0; i < maxBadConnRetries; i++ { + dc, releaseConn, si, err := s.connStmt() + if err != nil { + if err == driver.ErrBadConn { + continue + } + return nil, err + } + + res, err = resultFromStatement(driverStmt{dc, si}, args...) + releaseConn(err) + if err != driver.ErrBadConn { + return res, err + } + } + return nil, driver.ErrBadConn +} + +func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) + } + + dargs, err := driverArgs(&ds, args) + if err != nil { + return nil, err + } + + ds.Lock() + resi, err := ds.si.Exec(dargs) + ds.Unlock() + if err != nil { + return nil, err + } + return driverResult{ds.Locker, resi}, nil +} + +// connStmt returns a free driver connection on which to execute the +// statement, a function to call to release the connection, and a +// statement bound to that connection. +func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { + if err = s.stickyErr; err != nil { + return + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + err = errors.New("sql: statement is closed") + return + } + + // In a transaction, we always use the connection that the + // transaction was created on. + if s.tx != nil { + s.mu.Unlock() + ci, err = s.tx.grabConn() // blocks, waiting for the connection. + if err != nil { + return + } + releaseConn = func(error) {} + return ci, releaseConn, s.txsi.si, nil + } + + for i := 0; i < len(s.css); i++ { + v := s.css[i] + _, err := s.db.connIfFree(v.dc) + if err == nil { + s.mu.Unlock() + return v.dc, v.dc.releaseConn, v.si, nil + } + if err == errConnClosed { + // Lazily remove dead conn from our freelist. + s.css[i] = s.css[len(s.css)-1] + s.css = s.css[:len(s.css)-1] + i-- + } + + } + s.mu.Unlock() + + // If all connections are busy, either wait for one to become available (if + // we've already hit the maximum number of open connections) or create a + // new one. + // + // TODO(bradfitz): or always wait for one? make configurable later? + dc, err := s.db.conn() + if err != nil { + return nil, nil, nil, err + } + + // Do another pass over the list to see whether this statement has + // already been prepared on the connection assigned to us. + s.mu.Lock() + for _, v := range s.css { + if v.dc == dc { + s.mu.Unlock() + return dc, dc.releaseConn, v.si, nil + } + } + s.mu.Unlock() + + // No luck; we need to prepare the statement on this connection + dc.Lock() + si, err = dc.prepareLocked(s.query) + dc.Unlock() + if err != nil { + s.db.putConn(dc, err) + return nil, nil, nil, err + } + s.mu.Lock() + cs := connStmt{dc, si} + s.css = append(s.css, cs) + s.mu.Unlock() + + return dc, dc.releaseConn, si, nil +} + +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var rowsi driver.Rows + for i := 0; i < maxBadConnRetries; i++ { + dc, releaseConn, si, err := s.connStmt() + if err != nil { + if err == driver.ErrBadConn { + continue + } + return nil, err + } + + rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...) + if err == nil { + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + rowsi: rowsi, + // releaseConn set below + } + s.db.addDep(s, rows) + rows.releaseConn = func(err error) { + releaseConn(err) + s.db.removeDep(s, rows) + } + return rows, nil + } + + releaseConn(err) + if err != driver.ErrBadConn { + return nil, err + } + } + return nil, driver.ErrBadConn +} + +func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) + } + + dargs, err := driverArgs(&ds, args) + if err != nil { + return nil, err + } + + ds.Lock() + rowsi, err := ds.si.Query(dargs) + ds.Unlock() + if err != nil { + return nil, err + } + return rowsi, nil +} + +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + rows, err := s.Query(args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// Close closes the statement. +func (s *Stmt) Close() error { + s.closemu.Lock() + defer s.closemu.Unlock() + + if s.stickyErr != nil { + return s.stickyErr + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + + if s.tx != nil { + s.txsi.Close() + s.mu.Unlock() + return nil + } + s.mu.Unlock() + + return s.db.removeDep(s, s) +} + +func (s *Stmt) finalClose() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.css != nil { + for _, v := range s.css { + s.db.noteUnusedDriverStatement(v.dc, v.si) + v.dc.removeOpenStmt(v.si) + } + s.css = nil + } + return nil +} + +// Rows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance through the rows: +// +// rows, err := db.Query("SELECT ...") +// ... +// defer rows.Close() +// for rows.Next() { +// var id int +// var name string +// err = rows.Scan(&id, &name) +// ... +// } +// err = rows.Err() // get any error encountered during iteration +// ... +type Rows struct { + dc *driverConn // owned; must call releaseConn when closed to release + releaseConn func(error) + rowsi driver.Rows + + closed bool + lastcols []driver.Value + lasterr error // non-nil only if closed is true + closeStmt driver.Stmt // if non-nil, statement to Close on close +} + +// Next prepares the next result row for reading with the Scan method. It +// returns true on success, or false if there is no next result row or an error +// happened while preparing it. Err should be consulted to distinguish between +// the two cases. +// +// Every call to Scan, even the first one, must be preceded by a call to Next. +func (rs *Rows) Next() bool { + if rs.closed { + return false + } + if rs.lastcols == nil { + rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) + } + rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr != nil { + rs.Close() + return false + } + return true +} + +// Err returns the error, if any, that was encountered during iteration. +// Err may be called after an explicit or implicit Close. +func (rs *Rows) Err() error { + if rs.lasterr == io.EOF { + return nil + } + return rs.lasterr +} + +// Columns returns the column names. +// Columns returns an error if the rows are closed, or if the rows +// are from QueryRow and there was a deferred error. +func (rs *Rows) Columns() ([]string, error) { + if rs.closed { + return nil, errors.New("sql: Rows are closed") + } + if rs.rowsi == nil { + return nil, errors.New("sql: no Rows available") + } + return rs.rowsi.Columns(), nil +} + +// Scan copies the columns in the current row into the values pointed +// at by dest. +// +// If an argument has type *[]byte, Scan saves in that argument a copy +// of the corresponding data. The copy is owned by the caller and can +// be modified and held indefinitely. The copy can be avoided by using +// an argument of type *RawBytes instead; see the documentation for +// RawBytes for restrictions on its use. +// +// If an argument has type *interface{}, Scan copies the value +// provided by the underlying driver without conversion. If the value +// is of type []byte, a copy is made and the caller owns the result. +func (rs *Rows) Scan(dest ...interface{}) error { + if rs.closed { + return errors.New("sql: Rows are closed") + } + if rs.lastcols == nil { + return errors.New("sql: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssign(dest[i], sv) + if err != nil { + return fmt.Errorf("sql: Scan error on column index %d: %v", i, err) + } + } + return nil +} + +var rowsCloseHook func(*Rows, *error) + +// Close closes the Rows, preventing further enumeration. If Next returns +// false, the Rows are closed automatically and it will suffice to check the +// result of Err. Close is idempotent and does not affect the result of Err. +func (rs *Rows) Close() error { + if rs.closed { + return nil + } + rs.closed = true + err := rs.rowsi.Close() + if fn := rowsCloseHook; fn != nil { + fn(rs, &err) + } + if rs.closeStmt != nil { + rs.closeStmt.Close() + } + rs.releaseConn(err) + return err +} + +// Row is the result of calling QueryRow to select a single row. +type Row struct { + // One of these two will be non-nil: + err error // deferred error for easy chaining + rows *Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns ErrNoRows. +func (r *Row) Scan(dest ...interface{}) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + defer r.rows.Close() + for _, dp := range dest { + if _, ok := dp.(*RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + if err := r.rows.Close(); err != nil { + return err + } + + return nil +} + +// A Result summarizes an executed SQL command. +type Result interface { + // LastInsertId returns the integer generated by the database + // in response to a command. Typically this will be from an + // "auto increment" column when inserting a new row. Not all + // databases support this feature, and the syntax of such + // statements varies. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by an + // update, insert, or delete. Not every database or database + // driver may support this. + RowsAffected() (int64, error) +} + +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() +} + +func stack() string { + var buf [2 << 10]byte + return string(buf[:runtime.Stack(buf[:], false)]) +} + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + fn() + lk.Unlock() +} diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go new file mode 100644 index 000000000..34efdf254 --- /dev/null +++ b/src/database/sql/sql_test.go @@ -0,0 +1,1987 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "math/rand" + "reflect" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +func init() { + type dbConn struct { + db *DB + c *driverConn + } + freedFrom := make(map[dbConn]string) + putConnHook = func(db *DB, c *driverConn) { + idx := -1 + for i, v := range db.freeConn { + if v == c { + idx = i + break + } + } + if idx >= 0 { + // print before panic, as panic may get lost due to conflicting panic + // (all goroutines asleep) elsewhere, since we might not unlock + // the mutex in freeConn here. + println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) + panic("double free of conn.") + } + freedFrom[dbConn{db, c}] = stack() + } +} + +const fakeDBName = "foo" + +var chrisBirthday = time.Unix(123456789, 0) + +func newTestDB(t testing.TB, name string) *DB { + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + if name == "people" { + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1) + exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) + exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + } + if name == "magicquery" { + // Magic table name and column, known by fakedb_test.go. + exec(t, db, "CREATE|magicquery|op=string,millis=int32") + exec(t, db, "INSERT|magicquery|op=sleep,millis=10") + } + return db +} + +func exec(t testing.TB, db *DB, query string, args ...interface{}) { + _, err := db.Exec(query, args...) + if err != nil { + t.Fatalf("Exec of %q: %v", query, err) + } +} + +func closeDB(t testing.TB, db *DB) { + if e := recover(); e != nil { + fmt.Printf("Panic: %v\n", e) + panic(e) + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + for i, dc := range db.freeConn { + if n := len(dc.openStmt); n > 0 { + // Just a sanity check. This is legal in + // general, but if we make the tests clean up + // their statements first, then we can safely + // verify this is always zero here, and any + // other value is a leak. + t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n) + } + } + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } + db.mu.Lock() + count := db.numOpen + db.mu.Unlock() + if count != 0 { + t.Fatalf("%d connections still open after closing DB", db.numOpen) + } +} + +// numPrepares assumes that db has exactly 1 idle conn and returns +// its count of calls to Prepare +func numPrepares(t *testing.T, db *DB) int { + if n := len(db.freeConn); n != 1 { + t.Fatalf("free conns = %d; want 1", n) + } + return db.freeConn[0].ci.(*fakeConn).numPrepare +} + +func (db *DB) numDeps() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.dep) +} + +// Dependencies are closed via a goroutine, so this polls waiting for +// numDeps to fall to want, waiting up to d. +func (db *DB) numDepsPollUntil(want int, d time.Duration) int { + deadline := time.Now().Add(d) + for { + n := db.numDeps() + if n <= want || time.Now().After(deadline) { + return n + } + time.Sleep(50 * time.Millisecond) + } +} + +func (db *DB) numFreeConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.freeConn) +} + +func (db *DB) dumpDeps(t *testing.T) { + for fc := range db.dep { + db.dumpDep(t, 0, fc, map[finalCloser]bool{}) + } +} + +func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) { + seen[dep] = true + indent := strings.Repeat(" ", depth) + ds := db.dep[dep] + for k := range ds { + t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k) + if fc, ok := k.(finalCloser); ok { + if !seen[fc] { + db.dumpDep(t, depth+1, fc, seen) + } + } + } +} + +func TestQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestByteOwnership(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|name,photo|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + name []byte + photo RawBytes + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.name, &r.photo) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + corruptMemory := []byte("\xffPHOTO") + want := []row{ + {name: []byte("Alice"), photo: corruptMemory}, + {name: []byte("Bob"), photo: corruptMemory}, + {name: []byte("Chris"), photo: corruptMemory}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + var photo RawBytes + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err == nil { + t.Error("want error scanning into RawBytes from QueryRow") + } +} + +func TestRowsColumns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Columns: %v", err) + } + want := []string{"age", "name"} + if !reflect.DeepEqual(cols, want) { + t.Errorf("got %#v; want %#v", cols, want) + } + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + +func TestQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + var birthday time.Time + + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) + if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") { + t.Errorf("expected error from wrong number of arguments; actually got: %v", err) + } + + err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday) + if err != nil || !birthday.Equal(chrisBirthday) { + t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday) + } + + err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) + if err != nil { + t.Fatalf("age QueryRow+Scan: %v", err) + } + if name != "Bob" { + t.Errorf("expected name Bob, got %q", name) + } + if age != 2 { + t.Errorf("expected age 2, got %d", age) + } + + err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name) + if err != nil { + t.Fatalf("name QueryRow+Scan: %v", err) + } + if name != "Alice" { + t.Errorf("expected name Alice, got %q", name) + } + if age != 1 { + t.Errorf("expected age 1, got %d", age) + } + + var photo []byte + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err != nil { + t.Fatalf("photo QueryRow+Scan: %v", err) + } + want := []byte("APHOTO") + if !reflect.DeepEqual(photo, want) { + t.Errorf("photo = %q; want %q", photo, want) + } +} + +func TestStatementErrorAfterClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + var name string + err = stmt.QueryRow("foo").Scan(&name) + if err == nil { + t.Errorf("expected error from QueryRow.Scan after Stmt.Close") + } +} + +func TestStatementQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + defer stmt.Close() + var age int + for n, tt := range []struct { + name string + want int + }{ + {"Alice", 1}, + {"Bob", 2}, + {"Chris", 3}, + } { + if err := stmt.QueryRow(tt.name).Scan(&age); err != nil { + t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err) + } else if age != tt.want { + t.Errorf("%d: age=%d, want %d", n, age, tt.want) + } + } +} + +// golang.org/issue/3734 +func TestStatementQueryRowConcurrent(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + defer stmt.Close() + + const n = 10 + ch := make(chan error, n) + for i := 0; i < n; i++ { + go func() { + var age int + err := stmt.QueryRow("Alice").Scan(&age) + if err == nil && age != 1 { + err = fmt.Errorf("unexpected age %d", age) + } + ch <- err + }() + } + for i := 0; i < n; i++ { + if err := <-ch; err != nil { + t.Error(err) + } + } +} + +// just a test of fakedb itself +func TestBogusPreboundParameters(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") + if err == nil { + t.Fatalf("expected error") + } + if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` { + t.Errorf("unexpected error: %v", err) + } +} + +func TestExec(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Errorf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + type execTest struct { + args []interface{} + wantErr string + } + execTests := []execTest{ + // Okay: + {[]interface{}{"Brad", 31}, ""}, + {[]interface{}{"Brad", int64(31)}, ""}, + {[]interface{}{"Bob", "32"}, ""}, + {[]interface{}{7, 9}, ""}, + + // Invalid conversions: + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, + + // Wrong number of args: + {[]interface{}{}, "sql: expected 2 arguments, got 0"}, + {[]interface{}{1, 2, 3}, "sql: expected 2 arguments, got 3"}, + } + for n, et := range execTests { + _, err := stmt.Exec(et.args...) + errStr := "" + if err != nil { + errStr = err.Error() + } + if errStr != et.wantErr { + t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q", + n, et.args, errStr, et.wantErr) + } + } +} + +func TestTxPrepare(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + stmt, err := tx.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + _, err = stmt.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + // Commit() should have closed the statement + if !stmt.closed { + t.Fatal("Stmt not closed after Commit") + } +} + +func TestTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + // Commit() should have closed the statement + if !txs.closed { + t.Fatal("Stmt not closed after Commit") + } +} + +// Issue: http://golang.org/issue/2784 +// This test didn't fail before because we got lucky with the fakedb driver. +// It was failing, and now not, in github.com/bradfitz/go-sql-test +func TestTxQuery(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + exec(t, db, "INSERT|t1|name=Alice") + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r, err := tx.Query("SELECT|t1|name|") + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected one row") + } + + var x string + err = r.Scan(&x) + if err != nil { + t.Fatal(err) + } +} + +func TestTxQueryInvalid(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Query("SELECT|t1|name|") + if err == nil { + t.Fatal("Error expected") + } +} + +// Tests fix for issue 4433, that retries in Begin happen when +// conn.Begin() returns ErrBadConn +func TestTxErrBadConn(t *testing.T) { + db, err := Open("test", fakeDBName+";badConn") + if err != nil { + t.Fatalf("Open: %v", err) + } + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } +} + +// Tests fix for issue 2542, that we release a lock when querying on +// a closed connection. +func TestIssue2542Deadlock(t *testing.T) { + db := newTestDB(t, "people") + closeDB(t, db) + for i := 0; i < 2; i++ { + _, err := db.Query("SELECT|people|age,name|") + if err == nil { + t.Fatalf("expected error") + } + } +} + +// From golang.org/issue/3865 +func TestCloseStmtBeforeRows(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + s, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + r, err := s.Query() + if err != nil { + s.Close() + t.Fatal(err) + } + + err = s.Close() + if err != nil { + t.Fatal(err) + } + + r.Close() +} + +// Tests fix for issue 2788, that we bind nil to a []byte if the +// value in the column is sql null +func TestNullByteSlice(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + exec(t, db, "INSERT|t|id=10,name=?", nil) + + var name []byte + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatal(err) + } + if name != nil { + t.Fatalf("name []byte should be nil for null column value, got: %#v", name) + } + + exec(t, db, "INSERT|t|id=11,name=?", "bob") + err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name) + if err != nil { + t.Fatal(err) + } + if string(name) != "bob" { + t.Fatalf("name []byte should be bob, got: %q", string(name)) + } +} + +func TestPointerParamsAndScans(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + + bob := "bob" + var name *string + + name = &bob + exec(t, db, "INSERT|t|id=10,name=?", name) + name = nil + exec(t, db, "INSERT|t|id=20,name=?", name) + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatalf("querying id 10: %v", err) + } + if name == nil { + t.Errorf("id 10's name = nil; want bob") + } else if *name != "bob" { + t.Errorf("id 10's name = %q; want bob", *name) + } + + err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name) + if err != nil { + t.Fatalf("querying id 20: %v", err) + } + if name != nil { + t.Errorf("id 20 = %q; want nil", *name) + } +} + +func TestQueryRowClosingStmt(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name) + if err != nil { + t.Fatal(err) + } + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 free conn") + } + fakeConn := db.freeConn[0].ci.(*fakeConn) + if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { + t.Errorf("statement close mismatch: made %d, closed %d", made, closed) + } +} + +// Test issue 6651 +func TestIssue6651(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + var v string + + want := "error in rows.Next" + rowsCursorNextHook = func(dest []driver.Value) error { + return fmt.Errorf(want) + } + defer func() { rowsCursorNextHook = nil }() + err := db.QueryRow("SELECT|people|name|").Scan(&v) + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err, want) + } + rowsCursorNextHook = nil + + want = "error in rows.Close" + rowsCloseHook = func(rows *Rows, err *error) { + *err = fmt.Errorf(want) + } + defer func() { rowsCloseHook = nil }() + err = db.QueryRow("SELECT|people|name|").Scan(&v) + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err, want) + } +} + +type nullTestRow struct { + nullParam interface{} + notNullParam interface{} + scanNullVal interface{} +} + +type nullTestSpec struct { + nullType string + notNullType string + rows [6]nullTestRow +} + +func TestNullStringParam(t *testing.T) { + spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{ + {NullString{"aqua", true}, "", NullString{"aqua", true}}, + {NullString{"brown", false}, "", NullString{"", false}}, + {"chartreuse", "", NullString{"chartreuse", true}}, + {NullString{"darkred", true}, "", NullString{"darkred", true}}, + {NullString{"eel", false}, "", NullString{"", false}}, + {"foo", NullString{"black", false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullInt64Param(t *testing.T) { + spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{ + {NullInt64{31, true}, 1, NullInt64{31, true}}, + {NullInt64{-22, false}, 1, NullInt64{0, false}}, + {22, 1, NullInt64{22, true}}, + {NullInt64{33, true}, 1, NullInt64{33, true}}, + {NullInt64{222, false}, 1, NullInt64{0, false}}, + {0, NullInt64{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullFloat64Param(t *testing.T) { + spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{ + {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}}, + {NullFloat64{13.1, false}, 1, NullFloat64{0, false}}, + {-22.9, 1, NullFloat64{-22.9, true}}, + {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}}, + {NullFloat64{222, false}, 1, NullFloat64{0, false}}, + {10, NullFloat64{31.2, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullBoolParam(t *testing.T) { + spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{ + {NullBool{false, true}, true, NullBool{false, true}}, + {NullBool{true, false}, false, NullBool{false, false}}, + {true, true, NullBool{true, true}}, + {NullBool{true, true}, false, NullBool{true, true}}, + {NullBool{true, false}, true, NullBool{false, false}}, + {true, NullBool{true, false}, nil}, + }} + nullTestRun(t, spec) +} + +func nullTestRun(t *testing.T, spec nullTestSpec) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType)) + + // Inserts with db.Exec: + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam) + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam) + + // Inserts with a prepared statement: + stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt.Close() + if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { + t.Errorf("exec insert chris: %v", err) + } + if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil { + t.Errorf("exec insert dave: %v", err) + } + if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil { + t.Errorf("exec insert eleanor: %v", err) + } + + // Can't put null val into non-null col + if _, err := stmt.Exec(6, "bob", spec.rows[5].nullParam, spec.rows[5].notNullParam); err == nil { + t.Errorf("expected error inserting nil val with prepared statement Exec") + } + + _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil) + if err == nil { + // TODO: this test fails, but it's just because + // fakeConn implements the optional Execer interface, + // so arguably this is the correct behavior. But + // maybe I should flesh out the fakeConn.Exec + // implementation so this properly fails. + // t.Errorf("expected error inserting nil name with Exec") + } + + paramtype := reflect.TypeOf(spec.rows[0].nullParam) + bindVal := reflect.New(paramtype).Interface() + + for i := 0; i < 5; i++ { + id := i + 1 + if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil { + t.Errorf("id=%d Scan: %v", id, err) + } + bindValDeref := reflect.ValueOf(bindVal).Elem().Interface() + if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) { + t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal) + } + } +} + +// golang.org/issue/4859 +func TestQueryRowNilScanDest(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name *string // nil pointer + err := db.QueryRow("SELECT|people|name|").Scan(name) + want := "sql: Scan error on column index 0: destination pointer is nil" + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err.Error(), want) + } +} + +func TestIssue4902(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + opens0 := driver.openCount + + var stmt *Stmt + var err error + for i := 0; i < 10; i++ { + stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } + + opens := driver.openCount - opens0 + if opens > 1 { + t.Errorf("opens = %d; want <= 1", opens) + t.Logf("db = %#v", db) + t.Logf("driver = %#v", driver) + t.Logf("stmt = %#v", stmt) + } +} + +// Issue 3857 +// This used to deadlock. +func TestSimultaneousQueries(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r1, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r1.Close() + + r2, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r2.Close() +} + +func TestMaxIdleConns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 1 { + t.Errorf("freeConns = %d; want 1", got) + } + + db.SetMaxIdleConns(0) + + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns after set to zero = %d; want 0", got) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns = %d; want 0", got) + } +} + +func TestMaxOpenConns(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + + // Force the number of open connections to 0 so we can get an accurate + // count for the test + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Sleep for twice the expected length of time for the + // batch of 50 queries above to finish before starting + // the next round. + time.Sleep(2 * sleepMillis * time.Millisecond) + } + wg.Wait() + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(20, time.Second); n > 20 { + t.Errorf("number of dependencies = %d; expected <= 20", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + driver.mu.Unlock() + + if opens > 10 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Errorf("db connections opened = %d; want <= 10", opens) + db.dumpDeps(t) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(10, time.Second); n > 10 { + t.Errorf("number of dependencies = %d; expected <= 10", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(5) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(5, time.Second); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(0) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(5, time.Second); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } +} + +func TestSingleOpenConn(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + rows, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + // shouldn't deadlock + rows, err = db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } +} + +// golang.org/issue/5323 +func TestStmtCloseDeps(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + openDelta0 := opens0 - closes0 + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Sleep for twice the expected length of time for the + // batch of 50 queries above to finish before starting + // the next round. + time.Sleep(2 * sleepMillis * time.Millisecond) + } + wg.Wait() + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(4, time.Second); n > 4 { + t.Errorf("number of dependencies = %d; expected <= 4", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + openDelta := (driver.openCount - driver.closeCount) - openDelta0 + driver.mu.Unlock() + + if openDelta > 2 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Logf("open delta = %d", openDelta) + t.Errorf("db connections opened = %d; want <= 2", openDelta) + db.dumpDeps(t) + } + + if len(stmt.css) > nquery { + t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(2, time.Second); n > 2 { + t.Errorf("number of dependencies = %d; expected <= 2", n) + db.dumpDeps(t) + } + + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } +} + +// golang.org/issue/5046 +func TestCloseConnBeforeStmts(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v; from %s", err, stack()) + db.dumpDeps(t) + t.Errorf("DB = %#v", db) + } + }) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn)) + } + dc := db.freeConn[0] + if dc.closed { + t.Errorf("conn shouldn't be closed") + } + + if n := len(dc.openStmt); n != 1 { + t.Errorf("driverConn num openStmt = %d; want 1", n) + } + err = db.Close() + if err != nil { + t.Errorf("db Close = %v", err) + } + if !dc.closed { + t.Errorf("after db.Close, driverConn should be closed") + } + if n := len(dc.openStmt); n != 0 { + t.Errorf("driverConn num openStmt = %d; want 0", n) + } + + err = stmt.Close() + if err != nil { + t.Errorf("Stmt close = %v", err) + } + + if !dc.closed { + t.Errorf("conn should be closed") + } + if dc.ci != nil { + t.Errorf("after Stmt Close, driverConn's Conn interface should be nil") + } +} + +// golang.org/issue/5283: don't release the Rows' connection in Close +// before calling Stmt.Close. +func TestRowsCloseOrder(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxIdleConns(0) + setStrictFakeConnClose(t) + defer setStrictFakeConnClose(nil) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + err = rows.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestRowsImplicitClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + + want, fail := 2, errors.New("fail") + r := rows.rowsi.(*rowsCursor) + r.errPos, r.err = want, fail + + got := 0 + for rows.Next() { + got++ + } + if got != want { + t.Errorf("got %d rows, want %d", got, want) + } + if err := rows.Err(); err != fail { + t.Errorf("got error %v, want %v", err, fail) + } + if !r.closed { + t.Errorf("r.closed is false, want true") + } +} + +func TestStmtCloseOrder(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxIdleConns(0) + setStrictFakeConnClose(t) + defer setStrictFakeConnClose(nil) + + _, err := db.Query("SELECT|non_existent|name|") + if err == nil { + t.Fatal("Quering non-existent table should fail") + } +} + +// golang.org/issue/5781 +func TestErrBadConnReconnect(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + + simulateBadConn := func(name string, hook *func() bool, op func() error) { + broken, retried := false, false + numOpen := db.numOpen + + // simulate a broken connection on the first try + *hook = func() bool { + if !broken { + broken = true + return true + } + retried = true + return false + } + + if err := op(); err != nil { + t.Errorf(name+": %v", err) + return + } + + if !broken || !retried { + t.Error(name + ": Failed to simulate broken connection") + } + *hook = nil + + if numOpen != db.numOpen { + t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen) + numOpen = db.numOpen + } + } + + // db.Exec + dbExec := func() error { + _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true) + return err + } + simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec) + simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec) + + // db.Query + dbQuery := func() error { + rows, err := db.Query("SELECT|t1|age,name|") + if err == nil { + err = rows.Close() + } + return err + } + simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery) + simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery) + + // db.Prepare + simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error { + stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?") + if err != nil { + return err + } + stmt.Close() + return nil + }) + + // Provide a way to force a re-prepare of a statement on next execution + forcePrepare := func(stmt *Stmt) { + stmt.css = nil + } + + // stmt.Exec + stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt1.Close() + // make sure we must prepare the stmt first + forcePrepare(stmt1) + + stmtExec := func() error { + _, err := stmt1.Exec("Gopher", 3, false) + return err + } + simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec) + simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec) + + // stmt.Query + stmt2, err := db.Prepare("SELECT|t1|age,name|") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt2.Close() + // make sure we must prepare the stmt first + forcePrepare(stmt2) + + stmtQuery := func() error { + rows, err := stmt2.Query() + if err == nil { + err = rows.Close() + } + return err + } + simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery) + simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery) +} + +type concurrentTest interface { + init(t testing.TB, db *DB) + finish(t testing.TB) + test(t testing.TB) error +} + +type concurrentDBQueryTest struct { + db *DB +} + +func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBQueryTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentDBExecTest struct { + db *DB +} + +func (c *concurrentDBExecTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBExecTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBExecTest) test(t testing.TB) error { + _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentStmtQueryTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentStmtExecTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentTxQueryTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxQueryTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxExecTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxExecTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxExecTest) test(t testing.TB) error { + _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentTxStmtQueryTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxStmtExecTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentRandomTest struct { + tests []concurrentTest +} + +func (c *concurrentRandomTest) init(t testing.TB, db *DB) { + c.tests = []concurrentTest{ + new(concurrentDBQueryTest), + new(concurrentDBExecTest), + new(concurrentStmtQueryTest), + new(concurrentStmtExecTest), + new(concurrentTxQueryTest), + new(concurrentTxExecTest), + new(concurrentTxStmtQueryTest), + new(concurrentTxStmtExecTest), + } + for _, ct := range c.tests { + ct.init(t, db) + } +} + +func (c *concurrentRandomTest) finish(t testing.TB) { + for _, ct := range c.tests { + ct.finish(t) + } +} + +func (c *concurrentRandomTest) test(t testing.TB) error { + ct := c.tests[rand.Intn(len(c.tests))] + return ct.test(t) +} + +func doConcurrentTest(t testing.TB, ct concurrentTest) { + maxProcs, numReqs := 1, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + ct.init(t, db) + defer ct.finish(t) + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for range reqs { + err := ct.test(t) + if err != nil { + wg.Done() + continue + } + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + +func manyConcurrentQueries(t testing.TB) { + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for range reqs { + rows, err := stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + wg.Done() + continue + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + +func TestIssue6081(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + drv := db.driver.(*fakeDriver) + drv.mu.Lock() + opens0 := drv.openCount + closes0 := drv.closeCount + drv.mu.Unlock() + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + rowsCloseHook = func(rows *Rows, err *error) { + *err = driver.ErrBadConn + } + defer func() { rowsCloseHook = nil }() + for i := 0; i < 10; i++ { + rows, err := stmt.Query() + if err != nil { + t.Fatal(err) + } + rows.Close() + } + if n := len(stmt.css); n > 1 { + t.Errorf("len(css slice) = %d; want <= 1", n) + } + stmt.Close() + if n := len(stmt.css); n != 0 { + t.Errorf("len(css slice) after Close = %d; want 0", n) + } + + drv.mu.Lock() + opens := drv.openCount - opens0 + closes := drv.closeCount - closes0 + drv.mu.Unlock() + if opens < 9 { + t.Errorf("opens = %d; want >= 9", opens) + } + if closes < 9 { + t.Errorf("closes = %d; want >= 9", closes) + } +} + +func TestConcurrency(t *testing.T) { + doConcurrentTest(t, new(concurrentDBQueryTest)) + doConcurrentTest(t, new(concurrentDBExecTest)) + doConcurrentTest(t, new(concurrentStmtQueryTest)) + doConcurrentTest(t, new(concurrentStmtExecTest)) + doConcurrentTest(t, new(concurrentTxQueryTest)) + doConcurrentTest(t, new(concurrentTxExecTest)) + doConcurrentTest(t, new(concurrentTxStmtQueryTest)) + doConcurrentTest(t, new(concurrentTxStmtExecTest)) + doConcurrentTest(t, new(concurrentRandomTest)) +} + +func TestConnectionLeak(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + // Start by opening defaultMaxIdleConns + rows := make([]*Rows, defaultMaxIdleConns) + // We need to SetMaxOpenConns > MaxIdleConns, so the DB can open + // a new connection and we can fill the idle queue with the released + // connections. + db.SetMaxOpenConns(len(rows) + 1) + for ii := range rows { + r, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + r.Next() + if err := r.Err(); err != nil { + t.Fatal(err) + } + rows[ii] = r + } + // Now we have defaultMaxIdleConns busy connections. Open + // a new one, but wait until the busy connections are released + // before returning control to DB. + drv := db.driver.(*fakeDriver) + drv.waitCh = make(chan struct{}, 1) + drv.waitingCh = make(chan struct{}, 1) + var wg sync.WaitGroup + wg.Add(1) + go func() { + r, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + r.Close() + wg.Done() + }() + // Wait until the goroutine we've just created has started waiting. + <-drv.waitingCh + // Now close the busy connections. This provides a connection for + // the blocked goroutine and then fills up the idle queue. + for _, v := range rows { + v.Close() + } + // At this point we give the new connection to DB. This connection is + // now useless, since the idle queue is full and there are no pending + // requests. DB should deal with this situation without leaking the + // connection. + drv.waitCh <- struct{}{} + wg.Wait() +} + +func BenchmarkConcurrentDBExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentDBExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentRandom(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentRandomTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} |