diff options
author | Ondřej Surý <ondrej@sury.org> | 2012-04-06 15:14:11 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2012-04-06 15:14:11 +0200 |
commit | 505c19580e0f43fe5224431459cacb7c21edd93d (patch) | |
tree | 79e2634c253d60afc0cc0b2f510dc7dcbb48497b /src/pkg/database | |
parent | 1336a7c91e596c423a49d1194ea42d98bca0d958 (diff) | |
download | golang-505c19580e0f43fe5224431459cacb7c21edd93d.tar.gz |
Imported Upstream version 1upstream/1
Diffstat (limited to 'src/pkg/database')
-rw-r--r-- | src/pkg/database/sql/convert.go | 158 | ||||
-rw-r--r-- | src/pkg/database/sql/convert_test.go | 250 | ||||
-rw-r--r-- | src/pkg/database/sql/doc.txt | 46 | ||||
-rw-r--r-- | src/pkg/database/sql/driver/driver.go | 215 | ||||
-rw-r--r-- | src/pkg/database/sql/driver/types.go | 252 | ||||
-rw-r--r-- | src/pkg/database/sql/driver/types_test.go | 65 | ||||
-rw-r--r-- | src/pkg/database/sql/fakedb_test.go | 629 | ||||
-rw-r--r-- | src/pkg/database/sql/sql.go | 1063 | ||||
-rw-r--r-- | src/pkg/database/sql/sql_test.go | 614 |
9 files changed, 3292 insertions, 0 deletions
diff --git a/src/pkg/database/sql/convert.go b/src/pkg/database/sql/convert.go new file mode 100644 index 000000000..bfcb03ccf --- /dev/null +++ b/src/pkg/database/sql/convert.go @@ -0,0 +1,158 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" +) + +// subsetTypeArgs takes a slice of arguments from callers of the sql +// package and converts them into a slice of the driver package's +// "subset types". +func subsetTypeArgs(args []interface{}) ([]driver.Value, error) { + out := make([]driver.Value, len(args)) + for n, arg := range args { + var err error + out[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err) + } + } + return out, nil +} + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. Fall through. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + *d = s + return nil + case *[]byte: + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + *d = string(s) + return nil + case *interface{}: + bcopy := make([]byte, len(s)) + copy(bcopy, s) + *d = bcopy + return nil + case *[]byte: + *d = s + return nil + } + case nil: + switch d := dest.(type) { + case *[]byte: + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = fmt.Sprintf("%v", src) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if dv.Kind() == sv.Kind() { + dv.Set(sv) + return nil + } + + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + } + + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + return fmt.Sprintf("%v", src) +} diff --git a/src/pkg/database/sql/convert_test.go b/src/pkg/database/sql/convert_test.go new file mode 100644 index 000000000..9c362d732 --- /dev/null +++ b/src/pkg/database/sql/convert_test.go @@ -0,0 +1,250 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" +) + +var someTime = time.Unix(123, 0) +var answer int64 = 42 + +type conversionTest struct { + s, d interface{} // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantstr string + wantf32 float32 + wantf64 float64 + wanttime time.Time + wantbool bool // used if d is of type *bool + wanterr string + wantiface interface{} + wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr + wantnil bool // if true, *d must be *int64(nil) +} + +// Target variables for scanning into. +var ( + scanstr string + scanint int + scanint8 int8 + scanint16 int16 + scanint32 int32 + scanuint8 uint8 + scanuint16 uint16 + scanbool bool + scanf32 float32 + scanf64 float64 + scantime time.Time + scanptr *int64 + scaniface interface{} +) + +var conversionTests = []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, + + // To strings + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: `converting string "256" to a uint8: strconv.ParseUint: parsing "256": value out of range`}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: strconv.ParseInt: parsing "foo": invalid syntax`}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + {s: "True", d: &scanbool, wantbool: true}, + {s: "TRUE", d: &scanbool, wantbool: true}, + {s: "1", d: &scanbool, wantbool: true}, + {s: 1, d: &scanbool, wantbool: true}, + {s: int64(1), d: &scanbool, wantbool: true}, + {s: uint16(1), d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + {s: "false", d: &scanbool, wantbool: false}, + {s: "FALSE", d: &scanbool, wantbool: false}, + {s: "0", d: &scanbool, wantbool: false}, + {s: 0, d: &scanbool, wantbool: false}, + {s: int64(0), d: &scanbool, wantbool: false}, + {s: uint16(0), d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`}, + {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`}, + + // Floats + {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: float64(1.5)}, + + // Pointers + {s: interface{}(nil), d: &scanptr, wantnil: true}, + {s: int64(42), d: &scanptr, wantptr: &answer}, + + // To interface{} + {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)}, + {s: int64(1), d: &scaniface, wantiface: int64(1)}, + {s: "str", d: &scaniface, wantiface: "str"}, + {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, + {s: true, d: &scaniface, wantiface: true}, + {s: nil, d: &scaniface}, +} + +func intPtrValue(intptr interface{}) interface{} { + return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int() +} + +func intValue(intptr interface{}) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr interface{}) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func float64Value(ptr interface{}) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr interface{}) float32 { + return *(ptr.(*float32)) +} + +func timeValue(ptr interface{}) time.Time { + return *(ptr.(*time.Time)) +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests { + err := convertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.Error() + } + errf := func(format string, args ...interface{}) { + base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } + if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) + } + if ct.wantnil && *ct.d.(**int64) != nil { + errf("want nil, got %v", intPtrValue(ct.d)) + } + if ct.wantptr != nil { + if *ct.d.(**int64) == nil { + errf("want pointer to %v, got nil", *ct.wantptr) + } else if *ct.wantptr != intPtrValue(ct.d) { + errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d)) + } + } + if ifptr, ok := ct.d.(*interface{}); ok { + if !reflect.DeepEqual(ct.wantiface, scaniface) { + errf("want interface %#v, got %#v", ct.wantiface, scaniface) + continue + } + if srcBytes, ok := ct.s.([]byte); ok { + dstBytes := (*ifptr).([]byte) + if &dstBytes[0] == &srcBytes[0] { + errf("copy into interface{} didn't copy []byte data") + } + } + } + } +} + +func TestNullString(t *testing.T) { + var ns NullString + convertAssign(&ns, []byte("foo")) + if !ns.Valid { + t.Errorf("expecting not null") + } + if ns.String != "foo" { + t.Errorf("expecting foo; got %q", ns.String) + } + convertAssign(&ns, nil) + if ns.Valid { + t.Errorf("expecting null on nil") + } + if ns.String != "" { + t.Errorf("expecting blank on nil; got %q", ns.String) + } +} + +type valueConverterTest struct { + c driver.ValueConverter + in, out interface{} + err string +} + +var valueConverterTests = []valueConverterTest{ + {driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""}, + {driver.DefaultParameterConverter, NullString{"", false}, nil, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} diff --git a/src/pkg/database/sql/doc.txt b/src/pkg/database/sql/doc.txt new file mode 100644 index 000000000..fb1659548 --- /dev/null +++ b/src/pkg/database/sql/doc.txt @@ -0,0 +1,46 @@ +Goals of the sql and sql/driver packages: + +* Provide a generic database API for a variety of SQL or SQL-like + databases. There currently exist Go libraries for SQLite, MySQL, + and Postgres, but all with a very different feel, and often + a non-Go-like feel. + +* Feel like Go. + +* Care mostly about the common cases. Common SQL should be portable. + SQL edge cases or db-specific extensions can be detected and + conditionally used by the application. It is a non-goal to care + about every particular db's extension or quirk. + +* Separate out the basic implementation of a database driver + (implementing the sql/driver interfaces) vs the implementation + of all the user-level types and convenience methods. + In a nutshell: + + User Code ---> sql package (concrete types) ---> sql/driver (interfaces) + Database Driver -> sql (to register) + sql/driver (implement interfaces) + +* Make type casting/conversions consistent between all drivers. To + achieve this, most of the conversions are done in the db package, + not in each driver. The drivers then only have to deal with a + smaller set of types. + +* Be flexible with type conversions, but be paranoid about silent + truncation or other loss of precision. + +* Handle concurrency well. Users shouldn't need to care about the + database's per-connection thread safety issues (or lack thereof), + and shouldn't have to maintain their own free pools of connections. + The 'db' package should deal with that bookkeeping as needed. Given + an *sql.DB, it should be possible to share that instance between + multiple goroutines, without any extra synchronization. + +* Push complexity, where necessary, down into the sql+driver packages, + rather than exposing it to users. Said otherwise, the sql package + should expose an ideal database that's not finnicky about how it's + accessed, even if that's not true. + +* Provide optional interfaces in sql/driver for drivers to implement + for special cases or fastpaths. But the only party that knows about + those is the sql package. To user code, some stuff just might start + working or start working slightly faster. diff --git a/src/pkg/database/sql/driver/driver.go b/src/pkg/database/sql/driver/driver.go new file mode 100644 index 000000000..2f5280db8 --- /dev/null +++ b/src/pkg/database/sql/driver/driver.go @@ -0,0 +1,215 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package driver defines interfaces to be implemented by database +// drivers as used by package sql. +// +// Most code should use package sql. +package driver + +import "errors" + +// A driver Value is a value that drivers must be able to handle. +// A Value is either nil or an instance of one of these types: +// +// int64 +// float64 +// bool +// []byte +// string [*] everywhere except from Rows.Next. +// time.Time +type Value interface{} + +// Driver is the interface that must be implemented by a database +// driver. +type Driver interface { + // Open returns a new connection to the database. + // The name is a string in a driver-specific format. + // + // Open may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // + // The returned connection is only used by one goroutine at a + // time. + Open(name string) (Conn, error) +} + +// ErrSkip may be returned by some optional interfaces' methods to +// indicate at runtime that the fast path is unavailable and the sql +// package should continue as if the optional interface was not +// implemented. ErrSkip is only supported where explicitly +// documented. +var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") + +// ErrBadConn should be returned by a driver to signal to the sql +// package that a driver.Conn is in a bad state (such as the server +// having earlier closed the connection) and the sql package should +// retry on a new connection. +// +// To prevent duplicate operations, ErrBadConn should NOT be returned +// if there's a possibility that the database server might have +// performed the operation. Even if the server sends back an error, +// you shouldn't return ErrBadConn. +var ErrBadConn = errors.New("driver: bad connection") + +// Execer is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement Execer, the db package's DB.Exec will +// first prepare a query, execute the statement, and then close the +// statement. +// +// Exec may return ErrSkip. +type Execer interface { + Exec(query string, args []Value) (Result, error) +} + +// Conn is a connection to a database. It is not used concurrently +// by multiple goroutines. +// +// Conn is assumed to be stateful. +type Conn interface { + // Prepare returns a prepared statement, bound to this connection. + Prepare(query string) (Stmt, error) + + // Close invalidates and potentially stops any current + // prepared statements and transactions, marking this + // connection as no longer in use. + // + // Because the sql package maintains a free pool of + // connections and only calls Close when there's a surplus of + // idle connections, it shouldn't be necessary for drivers to + // do their own connection caching. + Close() error + + // Begin starts and returns a new transaction. + Begin() (Tx, error) +} + +// Result is the result of a query execution. +type Result interface { + // LastInsertId returns the database's auto-generated ID + // after, for example, an INSERT into a table with primary + // key. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by the + // query. + RowsAffected() (int64, error) +} + +// Stmt is a prepared statement. It is bound to a Conn and not +// used by multiple goroutines concurrently. +type Stmt interface { + // Close closes the statement. + // + // Closing a statement should not interrupt any outstanding + // query created from that statement. That is, the following + // order of operations is valid: + // + // * create a driver statement + // * call Query on statement, returning Rows + // * close the statement + // * read from Rows + // + // If closing a statement invalidates currently-running + // queries, the final step above will incorrectly fail. + // + // TODO(bradfitz): possibly remove the restriction above, if + // enough driver authors object and find it complicates their + // code too much. The sql package could be smarter about + // refcounting the statement and closing it at the appropriate + // time. + Close() error + + // NumInput returns the number of placeholder parameters. + // + // If NumInput returns >= 0, the sql package will sanity check + // argument counts from callers and return errors to the caller + // before the statement's Exec or Query methods are called. + // + // NumInput may also return -1, if the driver doesn't know + // its number of placeholders. In that case, the sql package + // will not sanity check Exec or Query argument counts. + NumInput() int + + // Exec executes a query that doesn't return rows, such + // as an INSERT or UPDATE. + Exec(args []Value) (Result, error) + + // Exec executes a query that may return rows, such as a + // SELECT. + Query(args []Value) (Rows, error) +} + +// ColumnConverter may be optionally implemented by Stmt if the +// the statement is aware of its own columns' types and can +// convert from any type to a driver Value. +type ColumnConverter interface { + // ColumnConverter returns a ValueConverter for the provided + // column index. If the type of a specific column isn't known + // or shouldn't be handled specially, DefaultValueConverter + // can be returned. + ColumnConverter(idx int) ValueConverter +} + +// Rows is an iterator over an executed query's results. +type Rows interface { + // Columns returns the names of the columns. The number of + // columns of the result is inferred from the length of the + // slice. If a particular column name isn't known, an empty + // string should be returned for that entry. + Columns() []string + + // Close closes the rows iterator. + Close() error + + // Next is called to populate the next row of data into + // the provided slice. The provided slice will be the same + // size as the Columns() are wide. + // + // The dest slice may be populated only with + // a driver Value type, but excluding string. + // All string values must be converted to []byte. + // + // Next should return io.EOF when there are no more rows. + Next(dest []Value) error +} + +// Tx is a transaction. +type Tx interface { + Commit() error + Rollback() error +} + +// RowsAffected implements Result for an INSERT or UPDATE operation +// which mutates a number of rows. +type RowsAffected int64 + +var _ Result = RowsAffected(0) + +func (RowsAffected) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available") +} + +func (v RowsAffected) RowsAffected() (int64, error) { + return int64(v), nil +} + +// ResultNoRows is a pre-defined Result for drivers to return when a DDL +// command (such as a CREATE TABLE) succeeds. It returns an error for both +// LastInsertId and RowsAffected. +var ResultNoRows noRows + +type noRows struct{} + +var _ Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available after DDL statement") +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available after DDL statement") +} diff --git a/src/pkg/database/sql/driver/types.go b/src/pkg/database/sql/driver/types.go new file mode 100644 index 000000000..3305354df --- /dev/null +++ b/src/pkg/database/sql/driver/types.go @@ -0,0 +1,252 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "fmt" + "reflect" + "strconv" + "time" +) + +// ValueConverter is the interface providing the ConvertValue method. +// +// Various implementations of ValueConverter are provided by the +// driver package to provide consistent implementations of conversions +// between drivers. The ValueConverters have several uses: +// +// * converting from the Value types as provided by the sql package +// into a database table's specific column type and making sure it +// fits, such as making sure a particular int64 fits in a +// table's uint16 column. +// +// * converting a value as given from the database into one of the +// driver Value types. +// +// * by the sql package, for converting from a driver's Value type +// to a user's type in a scan. +type ValueConverter interface { + // ConvertValue converts a value to a driver Value. + ConvertValue(v interface{}) (Value, error) +} + +// Valuer is the interface providing the Value method. +// +// Types implementing Valuer interface are able to convert +// themselves to a driver Value. +type Valuer interface { + // Value returns a driver Value. + Value() (Value, error) +} + +// Bool is a ValueConverter that converts input values to bools. +// +// The conversion rules are: +// - booleans are returned unchanged +// - for integer types, +// 1 is true +// 0 is false, +// other integers are an error +// - for strings and []byte, same rules as strconv.ParseBool +// - all other types are an error +var Bool boolType + +type boolType struct{} + +var _ ValueConverter = boolType{} + +func (boolType) String() string { return "Bool" } + +func (boolType) ConvertValue(src interface{}) (Value, error) { + switch s := src.(type) { + case bool: + return s, nil + case string: + b, err := strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + case []byte: + b, err := strconv.ParseBool(string(s)) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + } + + sv := reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iv := sv.Int() + if iv == 1 || iv == 0 { + return iv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uv := sv.Uint() + if uv == 1 || uv == 0 { + return uv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv) + } + + return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src) +} + +// Int32 is a ValueConverter that converts input values to int64, +// respecting the limits of an int32 value. +var Int32 int32Type + +type int32Type struct{} + +var _ ValueConverter = int32Type{} + +func (int32Type) ConvertValue(v interface{}) (Value, error) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64 := rv.Int() + if i64 > (1<<31)-1 || i64 < -(1<<31) { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return i64, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64 := rv.Uint() + if u64 > (1<<31)-1 { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return int64(u64), nil + case reflect.String: + i, err := strconv.Atoi(rv.String()) + if err != nil { + return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v) + } + return int64(i), nil + } + return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v) +} + +// String is a ValueConverter that converts its input to a string. +// If the value is already a string or []byte, it's unchanged. +// If the value is of another type, conversion to string is done +// with fmt.Sprintf("%v", v). +var String stringType + +type stringType struct{} + +func (stringType) ConvertValue(v interface{}) (Value, error) { + switch v.(type) { + case string, []byte: + return v, nil + } + return fmt.Sprintf("%v", v), nil +} + +// Null is a type that implements ValueConverter by allowing nil +// values but otherwise delegating to another ValueConverter. +type Null struct { + Converter ValueConverter +} + +func (n Null) ConvertValue(v interface{}) (Value, error) { + if v == nil { + return nil, nil + } + return n.Converter.ConvertValue(v) +} + +// NotNull is a type that implements ValueConverter by disallowing nil +// values but otherwise delegating to another ValueConverter. +type NotNull struct { + Converter ValueConverter +} + +func (n NotNull) ConvertValue(v interface{}) (Value, error) { + if v == nil { + return nil, fmt.Errorf("nil value not allowed") + } + return n.Converter.ConvertValue(v) +} + +// IsValue reports whether v is a valid Value parameter type. +// Unlike IsScanValue, IsValue permits the string type. +func IsValue(v interface{}) bool { + if IsScanValue(v) { + return true + } + if _, ok := v.(string); ok { + return true + } + return false +} + +// IsScanValue reports whether v is a valid Value scan type. +// Unlike IsValue, IsScanValue does not permit the string type. +func IsScanValue(v interface{}) bool { + if v == nil { + return true + } + switch v.(type) { + case int64, float64, []byte, bool, time.Time: + return true + } + return false +} + +// DefaultParameterConverter is the default implementation of +// ValueConverter that's used when a Stmt doesn't implement +// ColumnConverter. +// +// DefaultParameterConverter returns the given value directly if +// IsValue(value). Otherwise integer type are converted to +// int64, floats to float64, and strings to []byte. Other types are +// an error. +var DefaultParameterConverter defaultConverter + +type defaultConverter struct{} + +var _ ValueConverter = defaultConverter{} + +func (defaultConverter) ConvertValue(v interface{}) (Value, error) { + if IsValue(v) { + return v, nil + } + + if svi, ok := v.(Valuer); ok { + sv, err := svi.Value() + if err != nil { + return nil, err + } + if !IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return defaultConverter{}.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return nil, fmt.Errorf("uint64 values with high bit set are not supported") + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} diff --git a/src/pkg/database/sql/driver/types_test.go b/src/pkg/database/sql/driver/types_test.go new file mode 100644 index 000000000..ab82bca71 --- /dev/null +++ b/src/pkg/database/sql/driver/types_test.go @@ -0,0 +1,65 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "reflect" + "testing" + "time" +) + +type valueConverterTest struct { + c ValueConverter + in interface{} + out interface{} + err string +} + +var now = time.Now() +var answer int64 = 42 + +var valueConverterTests = []valueConverterTest{ + {Bool, "true", true, ""}, + {Bool, "True", true, ""}, + {Bool, []byte("t"), true, ""}, + {Bool, true, true, ""}, + {Bool, "1", true, ""}, + {Bool, 1, true, ""}, + {Bool, int64(1), true, ""}, + {Bool, uint16(1), true, ""}, + {Bool, "false", false, ""}, + {Bool, false, false, ""}, + {Bool, "0", false, ""}, + {Bool, 0, false, ""}, + {Bool, int64(0), false, ""}, + {Bool, uint16(0), false, ""}, + {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, + {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, + {DefaultParameterConverter, now, now, ""}, + {DefaultParameterConverter, (*int64)(nil), nil, ""}, + {DefaultParameterConverter, &answer, answer, ""}, + {DefaultParameterConverter, &now, now, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go new file mode 100644 index 000000000..184e7756c --- /dev/null +++ b/src/pkg/database/sql/fakedb_test.go @@ -0,0 +1,629 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "log" + "strconv" + "strings" + "sync" + "time" +) + +var _ = log.Printf + +// fakeDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntantically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE|<tablename>|<col>=<type>,<col>=<type>,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT|<tablename>|col=val,col2=val2,col3=? +// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? +// +// When opening a a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type fakeDriver struct { + mu sync.Mutex + openCount int + dbs map[string]*fakeDB +} + +type fakeDB struct { + name string + + mu sync.Mutex + free []*fakeConn + tables map[string]*table +} + +type table struct { + mu sync.Mutex + colname []string + coltype []string + rows []*row +} + +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type row struct { + cols []interface{} // must be same size as its table colname + coltype +} + +func (r *row) clone() *row { + nrow := &row{cols: make([]interface{}, len(r.cols))} + copy(nrow.cols, r.cols) + return nrow +} + +type fakeConn struct { + db *fakeDB // where to return ourselves to + + currTx *fakeTx + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int + numPrepare int +} + +func (c *fakeConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() +} + +type fakeTx struct { + c *fakeConn +} + +type fakeStmt struct { + c *fakeConn + q string // just for debugging + + cmd string + table string + + closed bool + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []string // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var fdriver driver.Driver = &fakeDriver{} + +func init() { + Register("test", fdriver) +} + +// Supports dsn forms: +// <dbname> +// <dbname>;<opts> (no currently supported options) +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, errors.New("fakedb: no database name") + } + name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + return &fakeConn{db: db}, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + db, ok := d.dbs[name] + if !ok { + db = &fakeDB{name: name} + d.dbs[name] = db + } + return db +} + +func (db *fakeDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.tables = nil +} + +func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { + db.mu.Lock() + defer db.mu.Unlock() + if db.tables == nil { + db.tables = make(map[string]*table) + } + if _, exist := db.tables[name]; exist { + return fmt.Errorf("table %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", + name, len(columnNames), len(columnTypes)) + } + db.tables[name] = &table{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *fakeDB) table(table string) (*table, bool) { + if db.tables == nil { + return nil, false + } + t, ok := db.tables[table] + return t, ok +} + +func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.table(table) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *fakeConn) Begin() (driver.Tx, error) { + if c.currTx != nil { + return nil, errors.New("already in a transaction") + } + c.currTx = &fakeTx{c: c} + return c.currTx, nil +} + +func (c *fakeConn) Close() error { + if c.currTx != nil { + return errors.New("can't close fakeConn; in a Transaction") + } + if c.db == nil { + return errors.New("can't close fakeConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("can't close; dangling statement(s)") + } + c.db = nil + return nil +} + +func checkSubsetTypes(args []driver.Value) error { + for n, arg := range args { + switch arg.(type) { + case int64, float64, bool, nil, []byte, string, time.Time: + default: + return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + } + } + return nil +} + +func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args of of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func errf(msg string, args ...interface{}) error { + return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) +} + +// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where where columns must always contain ? marks, +// just a limitation for fakedb) +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 3 { + stmt.Close() + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.table = parts[0] + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) + } + if value != "?" { + stmt.Close() + return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", + stmt.table, column) + } + stmt.whereCol = append(stmt.whereCol, column) + stmt.placeholders++ + } + return stmt, nil +} + +// parts are table|col=type,col2=type2 +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + stmt.Close() + return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are table|col=?,col2=val +func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) + } + stmt.colName = append(stmt.colName, column) + + if value != "?" { + var subsetVal interface{} + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + stmt.Close() + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + default: + stmt.Close() + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, "?") + } + } + return stmt, nil +} + +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + c.numPrepare++ + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + cmd := parts[0] + parts = parts[1:] + stmt := &fakeStmt{q: query, c: c, cmd: cmd} + c.incrStat(&c.stmtsMade) + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + return c.prepareSelect(stmt, parts) + case "CREATE": + return c.prepareCreate(stmt, parts) + case "INSERT": + return c.prepareInsert(stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + return stmt, nil +} + +func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + return s.placeholderConverter[idx] +} + +func (s *fakeStmt) Close() error { + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } + return nil +} + +var errClosed = errors.New("fakedb: statement has been closed") + +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + if s.closed { + return nil, errClosed + } + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.ResultNoRows, nil + case "CREATE": + if err := db.createTable(s.table, s.colName, s.colType); err != nil { + return nil, err + } + return driver.ResultNoRows, nil + case "INSERT": + return s.execInsert(args) + } + fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) + return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) +} + +func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + t.mu.Lock() + defer t.mu.Unlock() + + cols := make([]interface{}, len(t.colname)) + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val interface{} + if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { + val = args[argPos] + argPos++ + } else { + val = s.colValue[n] + } + cols[colidx] = val + } + + t.rows = append(t.rows, &row{cols: cols}) + return driver.RowsAffected(1), nil +} + +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + if s.closed { + return nil, errClosed + } + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + t.mu.Lock() + defer t.mu.Unlock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mrows := []*row{} +rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for widx, wcol := range s.whereCol { + idx := t.columnIndex(wcol) + if idx == -1 { + return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + continue rows + } + } + mrow := &row{cols: make([]interface{}, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] + } + mrows = append(mrows, mrow) + } + + cursor := &rowsCursor{ + pos: -1, + rows: mrows, + cols: s.colName, + } + return cursor, nil +} + +func (s *fakeStmt) NumInput() int { + return s.placeholders +} + +func (tx *fakeTx) Commit() error { + tx.c.currTx = nil + return nil +} + +func (tx *fakeTx) Rollback() error { + tx.c.currTx = nil + return nil +} + +type rowsCursor struct { + cols []string + pos int + rows []*row + closed bool + + // a clone of slices to give out to clients, indexed by the + // the original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte +} + +func (rc *rowsCursor) Close() error { + if !rc.closed { + for _, bs := range rc.bytesClone { + bs[0] = 255 // first byte corrupted + } + } + rc.closed = true + return nil +} + +func (rc *rowsCursor) Columns() []string { + return rc.cols +} + +func (rc *rowsCursor) Next(dest []driver.Value) error { + if rc.closed { + return errors.New("fakedb: cursor is closed") + } + rc.pos++ + if rc.pos >= len(rc.rows) { + return io.EOF // per interface spec + } + for i, v := range rc.rows[rc.pos].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the sql package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } + } + return nil +} + +func converterForType(typ string) driver.ValueConverter { + switch typ { + case "bool": + return driver.Bool + case "nullbool": + return driver.Null{Converter: driver.Bool} + case "int32": + return driver.Int32 + case "string": + return driver.NotNull{Converter: driver.String} + case "nullstring": + return driver.Null{Converter: driver.String} + case "int64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullint64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "float64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullfloat64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "datetime": + return driver.DefaultParameterConverter + } + panic("invalid fakedb column type of " + typ) +} diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go new file mode 100644 index 000000000..51a357b37 --- /dev/null +++ b/src/pkg/database/sql/sql.go @@ -0,0 +1,1063 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +package sql + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "sync" +) + +var drivers = make(map[string]driver.Driver) + +// Register makes a database driver available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, driver driver.Driver) { + if driver == nil { + panic("sql: Register driver is nil") + } + if _, dup := drivers[name]; dup { + panic("sql: Register called twice for driver " + name) + } + drivers[name] = driver +} + +// RawBytes is a byte slice that holds a reference to memory owned by +// the database itself. After a Scan into a RawBytes, the slice is only +// valid until the next call to Next, Scan, or Close. +type RawBytes []byte + +// NullString represents a string that may be null. +// NullString implements the Scanner interface so +// it can be used as a scan destination: +// +// var s NullString +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s) +// ... +// if s.Valid { +// // use s.String +// } else { +// // NULL value +// } +// +type NullString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullString) Scan(value interface{}) error { + if value == nil { + ns.String, ns.Valid = "", false + return nil + } + ns.Valid = true + return convertAssign(&ns.String, value) +} + +// Value implements the driver Valuer interface. +func (ns NullString) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.String, nil +} + +// NullInt64 represents an int64 that may be null. +// NullInt64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt64) Scan(value interface{}) error { + if value == nil { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int64, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int64, nil +} + +// NullFloat64 represents a float64 that may be null. +// NullFloat64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat64) Scan(value interface{}) error { + if value == nil { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Float64, value) +} + +// Value implements the driver Valuer interface. +func (n NullFloat64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Float64, nil +} + +// NullBool represents a bool that may be null. +// NullBool implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullBool) Scan(value interface{}) error { + if value == nil { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + return convertAssign(&n.Bool, value) +} + +// Value implements the driver Valuer interface. +func (n NullBool) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Bool, nil +} + +// Scanner is an interface used by Scan. +type Scanner interface { + // Scan assigns a value from a database driver. + // + // The src value will be of one of the following restricted + // set of types: + // + // int64 + // float64 + // bool + // []byte + // string + // time.Time + // nil - for NULL values + // + // An error should be returned if the value can not be stored + // without loss of information. + Scan(src interface{}) error +} + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. In such a case, QueryRow returns a placeholder *Row value that +// defers this error until a Scan. +var ErrNoRows = errors.New("sql: no rows in result set") + +// DB is a database handle. It's safe for concurrent use by multiple +// goroutines. +// +// If the underlying database driver has the concept of a connection +// and per-connection session state, the sql package manages creating +// and freeing connections automatically, including maintaining a free +// pool of idle connections. If observing session state is required, +// either do not share a *DB between multiple concurrent goroutines or +// create and observe all state only within a transaction. Once +// DB.Open is called, the returned Tx is bound to a single isolated +// connection. Once Tx.Commit or Tx.Rollback is called, that +// connection is returned to DB's idle connection pool. +type DB struct { + driver driver.Driver + dsn string + + mu sync.Mutex // protects freeConn and closed + freeConn []driver.Conn + closed bool +} + +// Open opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. +func Open(driverName, dataSourceName string) (*DB, error) { + driver, ok := drivers[driverName] + if !ok { + return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) + } + return &DB{driver: driver, dsn: dataSourceName}, nil +} + +// Close closes the database, releasing any open resources. +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + var err error + for _, c := range db.freeConn { + err1 := c.Close() + if err1 != nil { + err = err1 + } + } + db.freeConn = nil + db.closed = true + return err +} + +func (db *DB) maxIdleConns() int { + const defaultMaxIdleConns = 2 + // TODO(bradfitz): ask driver, if supported, for its default preference + // TODO(bradfitz): let users override? + return defaultMaxIdleConns +} + +// conn returns a newly-opened or cached driver.Conn +func (db *DB) conn() (driver.Conn, error) { + db.mu.Lock() + if db.closed { + db.mu.Unlock() + return nil, errors.New("sql: database is closed") + } + if n := len(db.freeConn); n > 0 { + conn := db.freeConn[n-1] + db.freeConn = db.freeConn[:n-1] + db.mu.Unlock() + return conn, nil + } + db.mu.Unlock() + return db.driver.Open(db.dsn) +} + +func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + for i, conn := range db.freeConn { + if conn != wanted { + continue + } + db.freeConn[i] = db.freeConn[len(db.freeConn)-1] + db.freeConn = db.freeConn[:len(db.freeConn)-1] + return wanted, true + } + return nil, false +} + +// putConnHook is a hook for testing. +var putConnHook func(*DB, driver.Conn) + +// putConn adds a connection to the db's free pool. +// err is optionally the last error that occured on this connection. +func (db *DB) putConn(c driver.Conn, err error) { + if err == driver.ErrBadConn { + // Don't reuse bad connections. + return + } + db.mu.Lock() + if putConnHook != nil { + putConnHook(db, c) + } + if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { + db.freeConn = append(db.freeConn, c) + db.mu.Unlock() + return + } + // TODO: check to see if we need this Conn for any prepared + // statements which are still active? + db.mu.Unlock() + c.Close() +} + +// Prepare creates a prepared statement for later execution. +func (db *DB) Prepare(query string) (*Stmt, error) { + var stmt *Stmt + var err error + for i := 0; i < 10; i++ { + stmt, err = db.prepare(query) + if err != driver.ErrBadConn { + break + } + } + return stmt, err +} + +func (db *DB) prepare(query string) (stmt *Stmt, err error) { + // TODO: check if db.driver supports an optional + // driver.Preparer interface and call that instead, if so, + // otherwise we make a prepared statement that's bound + // to a connection, and to execute this prepared statement + // we either need to use this connection (if it's free), else + // get a new connection + re-prepare + execute on that one. + ci, err := db.conn() + if err != nil { + return nil, err + } + defer db.putConn(ci, err) + si, err := ci.Prepare(query) + if err != nil { + return nil, err + } + stmt = &Stmt{ + db: db, + query: query, + css: []connStmt{{ci, si}}, + } + return stmt, nil +} + +// Exec executes a query without returning any rows. +func (db *DB) Exec(query string, args ...interface{}) (Result, error) { + sargs, err := subsetTypeArgs(args) + var res Result + for i := 0; i < 10; i++ { + res, err = db.exec(query, sargs) + if err != driver.ErrBadConn { + break + } + } + return res, err +} + +func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) { + ci, err := db.conn() + if err != nil { + return nil, err + } + defer db.putConn(ci, err) + + if execer, ok := ci.(driver.Execer); ok { + resi, err := execer.Exec(query, sargs) + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return result{resi}, nil + } + } + + sti, err := ci.Prepare(query) + if err != nil { + return nil, err + } + defer sti.Close() + + resi, err := sti.Exec(sargs) + if err != nil { + return nil, err + } + return result{resi}, nil +} + +// Query executes a query that returns rows, typically a SELECT. +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + stmt, err := db.Prepare(query) + if err != nil { + return nil, err + } + rows, err := stmt.Query(args...) + if err != nil { + stmt.Close() + return nil, err + } + rows.closeStmt = stmt + return rows, nil +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + rows, err := db.Query(query, args...) + return &Row{rows: rows, err: err} +} + +// Begin starts a transaction. The isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, error) { + var tx *Tx + var err error + for i := 0; i < 10; i++ { + tx, err = db.begin() + if err != driver.ErrBadConn { + break + } + } + return tx, err +} + +func (db *DB) begin() (tx *Tx, err error) { + ci, err := db.conn() + if err != nil { + return nil, err + } + txi, err := ci.Begin() + if err != nil { + db.putConn(ci, err) + return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err) + } + return &Tx{ + db: db, + ci: ci, + txi: txi, + }, nil +} + +// Driver returns the database's underlying driver. +func (db *DB) Driver() driver.Driver { + return db.driver +} + +// Tx is an in-progress database transaction. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the +// transaction fail with ErrTxDone. +type Tx struct { + db *DB + + // ci is owned exclusively until Commit or Rollback, at which point + // it's returned with putConn. + ci driver.Conn + txi driver.Tx + + // cimu is held while somebody is using ci (between grabConn + // and releaseConn) + cimu sync.Mutex + + // done transitions from false to true exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTxDone. + done bool +} + +var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") + +func (tx *Tx) close() { + if tx.done { + panic("double close") // internal error + } + tx.done = true + tx.db.putConn(tx.ci, nil) + tx.ci = nil + tx.txi = nil +} + +func (tx *Tx) grabConn() (driver.Conn, error) { + if tx.done { + return nil, ErrTxDone + } + tx.cimu.Lock() + return tx.ci, nil +} + +func (tx *Tx) releaseConn() { + tx.cimu.Unlock() +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + if tx.done { + return ErrTxDone + } + defer tx.close() + return tx.txi.Commit() +} + +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + if tx.done { + return ErrTxDone + } + defer tx.close() + return tx.txi.Rollback() +} + +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + // TODO(bradfitz): We could be more efficient here and either + // provide a method to take an existing Stmt (created on + // perhaps a different Conn), and re-create it on this Conn if + // necessary. Or, better: keep a map in DB of query string to + // Stmts, and have Stmt.Execute do the right thing and + // re-prepare if the Conn in use doesn't have that prepared + // statement. But we'll want to avoid caching the statement + // in the case where we only call conn.Prepare implicitly + // (such as in db.Exec or tx.Exec), but the caller package + // can't be holding a reference to the returned statement. + // Perhaps just looking at the reference count (by noting + // Stmt.Close) would be enough. We might also want a finalizer + // on Stmt to drop the reference count. + ci, err := tx.grabConn() + if err != nil { + return nil, err + } + defer tx.releaseConn() + + si, err := ci.Prepare(query) + if err != nil { + return nil, err + } + + stmt := &Stmt{ + db: tx.db, + tx: tx, + txsi: si, + query: query, + } + return stmt, nil +} + +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + // TODO(bradfitz): optimize this. Currently this re-prepares + // each time. This is fine for now to illustrate the API but + // we should really cache already-prepared statements + // per-Conn. See also the big comment in Tx.Prepare. + + if tx.db != stmt.db { + return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} + } + ci, err := tx.grabConn() + if err != nil { + return &Stmt{stickyErr: err} + } + defer tx.releaseConn() + si, err := ci.Prepare(stmt.query) + return &Stmt{ + db: tx.db, + tx: tx, + txsi: si, + query: stmt.query, + stickyErr: err, + } +} + +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + ci, err := tx.grabConn() + if err != nil { + return nil, err + } + defer tx.releaseConn() + + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + + if execer, ok := ci.(driver.Execer); ok { + resi, err := execer.Exec(query, sargs) + if err == nil { + return result{resi}, nil + } + if err != driver.ErrSkip { + return nil, err + } + } + + sti, err := ci.Prepare(query) + if err != nil { + return nil, err + } + defer sti.Close() + + resi, err := sti.Exec(sargs) + if err != nil { + return nil, err + } + return result{resi}, nil +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + if tx.done { + return nil, ErrTxDone + } + stmt, err := tx.Prepare(query) + if err != nil { + return nil, err + } + rows, err := stmt.Query(args...) + if err != nil { + stmt.Close() + return nil, err + } + rows.closeStmt = stmt + return rows, err +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + rows, err := tx.Query(query, args...) + return &Row{rows: rows, err: err} +} + +// connStmt is a prepared statement on a particular connection. +type connStmt struct { + ci driver.Conn + si driver.Stmt +} + +// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines. +type Stmt struct { + // Immutable: + db *DB // where we came from + query string // that created the Stmt + stickyErr error // if non-nil, this error is returned for all operations + + // If in a transaction, else both nil: + tx *Tx + txsi driver.Stmt + + mu sync.Mutex // protects the rest of the fields + closed bool + + // css is a list of underlying driver statement interfaces + // that are valid on particular connections. This is only + // used if tx == nil and one is found that has idle + // connections. If tx != nil, txsi is always used. + css []connStmt +} + +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, error) { + _, releaseConn, si, err := s.connStmt() + if err != nil { + return nil, err + } + defer releaseConn(nil) + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := si.NumInput(); want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) + } + + sargs := make([]driver.Value, len(args)) + + // Convert args to subset types. + if cc, ok := si.(driver.ColumnConverter); ok { + for n, arg := range args { + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if svi, ok := arg.(driver.Valuer); ok { + sv, err := svi.Value() + if err != nil { + return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err) + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv) + } + arg = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + sargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err) + } + if !driver.IsValue(sargs[n]) { + return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T", + arg, sargs[n]) + } + } + } else { + for n, arg := range args { + sargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err) + } + } + } + + resi, err := si.Exec(sargs) + if err != nil { + return nil, err + } + return result{resi}, nil +} + +// connStmt returns a free driver connection on which to execute the +// statement, a function to call to release the connection, and a +// statement bound to that connection. +func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) { + if err = s.stickyErr; err != nil { + return + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + err = errors.New("sql: statement is closed") + return + } + + // In a transaction, we always use the connection that the + // transaction was created on. + if s.tx != nil { + s.mu.Unlock() + ci, err = s.tx.grabConn() // blocks, waiting for the connection. + if err != nil { + return + } + releaseConn = func(error) { s.tx.releaseConn() } + return ci, releaseConn, s.txsi, nil + } + + var cs connStmt + match := false + for _, v := range s.css { + // TODO(bradfitz): lazily clean up entries in this + // list with dead conns while enumerating + if _, match = s.db.connIfFree(v.ci); match { + cs = v + break + } + } + s.mu.Unlock() + + // Make a new conn if all are busy. + // TODO(bradfitz): or wait for one? make configurable later? + if !match { + for i := 0; ; i++ { + ci, err := s.db.conn() + if err != nil { + return nil, nil, nil, err + } + si, err := ci.Prepare(s.query) + if err == driver.ErrBadConn && i < 10 { + continue + } + if err != nil { + return nil, nil, nil, err + } + s.mu.Lock() + cs = connStmt{ci, si} + s.css = append(s.css, cs) + s.mu.Unlock() + break + } + } + + conn := cs.ci + releaseConn = func(err error) { s.db.putConn(conn, err) } + return conn, releaseConn, cs.si, nil +} + +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + ci, releaseConn, si, err := s.connStmt() + if err != nil { + return nil, err + } + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := si.NumInput(); want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) + } + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + rowsi, err := si.Query(sargs) + if err != nil { + releaseConn(err) + return nil, err + } + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + db: s.db, + ci: ci, + releaseConn: releaseConn, + rowsi: rowsi, + } + return rows, nil +} + +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + rows, err := s.Query(args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// Close closes the statement. +func (s *Stmt) Close() error { + if s.stickyErr != nil { + return s.stickyErr + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + s.closed = true + + if s.tx != nil { + s.txsi.Close() + } else { + for _, v := range s.css { + if ci, match := s.db.connIfFree(v.ci); match { + v.si.Close() + s.db.putConn(ci, nil) + } else { + // TODO(bradfitz): care that we can't close + // this statement because the statement's + // connection is in use? + } + } + } + return nil +} + +// Rows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance through the rows: +// +// rows, err := db.Query("SELECT ...") +// ... +// for rows.Next() { +// var id int +// var name string +// err = rows.Scan(&id, &name) +// ... +// } +// err = rows.Err() // get any error encountered during iteration +// ... +type Rows struct { + db *DB + ci driver.Conn // owned; must call putconn when closed to release + releaseConn func(error) + rowsi driver.Rows + + closed bool + lastcols []driver.Value + lasterr error + closeStmt *Stmt // if non-nil, statement to Close on close +} + +// Next prepares the next result row for reading with the Scan method. +// It returns true on success, false if there is no next result row. +// Every call to Scan, even the first one, must be preceded by a call +// to Next. +func (rs *Rows) Next() bool { + if rs.closed { + return false + } + if rs.lasterr != nil { + return false + } + if rs.lastcols == nil { + rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) + } + rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr == io.EOF { + rs.Close() + } + return rs.lasterr == nil +} + +// Err returns the error, if any, that was encountered during iteration. +func (rs *Rows) Err() error { + if rs.lasterr == io.EOF { + return nil + } + return rs.lasterr +} + +// Columns returns the column names. +// Columns returns an error if the rows are closed, or if the rows +// are from QueryRow and there was a deferred error. +func (rs *Rows) Columns() ([]string, error) { + if rs.closed { + return nil, errors.New("sql: Rows are closed") + } + if rs.rowsi == nil { + return nil, errors.New("sql: no Rows available") + } + return rs.rowsi.Columns(), nil +} + +// Scan copies the columns in the current row into the values pointed +// at by dest. +// +// If an argument has type *[]byte, Scan saves in that argument a copy +// of the corresponding data. The copy is owned by the caller and can +// be modified and held indefinitely. The copy can be avoided by using +// an argument of type *RawBytes instead; see the documentation for +// RawBytes for restrictions on its use. +// +// If an argument has type *interface{}, Scan copies the value +// provided by the underlying driver without conversion. If the value +// is of type []byte, a copy is made and the caller owns the result. +func (rs *Rows) Scan(dest ...interface{}) error { + if rs.closed { + return errors.New("sql: Rows closed") + } + if rs.lasterr != nil { + return rs.lasterr + } + if rs.lastcols == nil { + return errors.New("sql: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssign(dest[i], sv) + if err != nil { + return fmt.Errorf("sql: Scan error on column index %d: %v", i, err) + } + } + for _, dp := range dest { + b, ok := dp.(*[]byte) + if !ok { + continue + } + if *b == nil { + // If the []byte is now nil (for a NULL value), + // don't fall through to below which would + // turn it into a non-nil 0-length byte slice + continue + } + if _, ok = dp.(*RawBytes); ok { + continue + } + clone := make([]byte, len(*b)) + copy(clone, *b) + *b = clone + } + return nil +} + +// Close closes the Rows, preventing further enumeration. If the +// end is encountered, the Rows are closed automatically. Close +// is idempotent. +func (rs *Rows) Close() error { + if rs.closed { + return nil + } + rs.closed = true + err := rs.rowsi.Close() + rs.releaseConn(err) + if rs.closeStmt != nil { + rs.closeStmt.Close() + } + return err +} + +// Row is the result of calling QueryRow to select a single row. +type Row struct { + // One of these two will be non-nil: + err error // deferred error for easy chaining + rows *Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns ErrNoRows. +func (r *Row) Scan(dest ...interface{}) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + for _, dp := range dest { + if _, ok := dp.(*RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + defer r.rows.Close() + if !r.rows.Next() { + return ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + + return nil +} + +// A Result summarizes an executed SQL command. +type Result interface { + LastInsertId() (int64, error) + RowsAffected() (int64, error) +} + +type result struct { + driver.Result +} diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go new file mode 100644 index 000000000..b29670586 --- /dev/null +++ b/src/pkg/database/sql/sql_test.go @@ -0,0 +1,614 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "runtime" + "strings" + "testing" + "time" +) + +func init() { + type dbConn struct { + db *DB + c driver.Conn + } + freedFrom := make(map[dbConn]string) + putConnHook = func(db *DB, c driver.Conn) { + for _, oc := range db.freeConn { + if oc == c { + // print before panic, as panic may get lost due to conflicting panic + // (all goroutines asleep) elsewhere, since we might not unlock + // the mutex in freeConn here. + println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) + panic("double free of conn.") + } + } + freedFrom[dbConn{db, c}] = stack() + } +} + +const fakeDBName = "foo" + +var chrisBirthday = time.Unix(123456789, 0) + +func newTestDB(t *testing.T, name string) *DB { + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + if name == "people" { + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1) + exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) + exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + } + return db +} + +func exec(t *testing.T, db *DB, query string, args ...interface{}) { + _, err := db.Exec(query, args...) + if err != nil { + t.Fatalf("Exec of %q: %v", query, err) + } +} + +func closeDB(t *testing.T, db *DB) { + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } +} + +// numPrepares assumes that db has exactly 1 idle conn and returns +// its count of calls to Prepare +func numPrepares(t *testing.T, db *DB) int { + if n := len(db.freeConn); n != 1 { + t.Fatalf("free conns = %d; want 1", n) + } + return db.freeConn[0].(*fakeConn).numPrepare +} + +func TestQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := len(db.freeConn); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestByteOwnership(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|name,photo|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + name []byte + photo RawBytes + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.name, &r.photo) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + corruptMemory := []byte("\xffPHOTO") + want := []row{ + {name: []byte("Alice"), photo: corruptMemory}, + {name: []byte("Bob"), photo: corruptMemory}, + {name: []byte("Chris"), photo: corruptMemory}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + var photo RawBytes + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err == nil { + t.Error("want error scanning into RawBytes from QueryRow") + } +} + +func TestRowsColumns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Columns: %v", err) + } + want := []string{"age", "name"} + if !reflect.DeepEqual(cols, want) { + t.Errorf("got %#v; want %#v", cols, want) + } +} + +func TestQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + var birthday time.Time + + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) + if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") { + t.Errorf("expected error from wrong number of arguments; actually got: %v", err) + } + + err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday) + if err != nil || !birthday.Equal(chrisBirthday) { + t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday) + } + + err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) + if err != nil { + t.Fatalf("age QueryRow+Scan: %v", err) + } + if name != "Bob" { + t.Errorf("expected name Bob, got %q", name) + } + if age != 2 { + t.Errorf("expected age 2, got %d", age) + } + + err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name) + if err != nil { + t.Fatalf("name QueryRow+Scan: %v", err) + } + if name != "Alice" { + t.Errorf("expected name Alice, got %q", name) + } + if age != 1 { + t.Errorf("expected age 1, got %d", age) + } + + var photo []byte + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err != nil { + t.Fatalf("photo QueryRow+Scan: %v", err) + } + want := []byte("APHOTO") + if !reflect.DeepEqual(photo, want) { + t.Errorf("photo = %q; want %q", photo, want) + } +} + +func TestStatementErrorAfterClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + var name string + err = stmt.QueryRow("foo").Scan(&name) + if err == nil { + t.Errorf("expected error from QueryRow.Scan after Stmt.Close") + } +} + +func TestStatementQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + defer stmt.Close() + var age int + for n, tt := range []struct { + name string + want int + }{ + {"Alice", 1}, + {"Bob", 2}, + {"Chris", 3}, + } { + if err := stmt.QueryRow(tt.name).Scan(&age); err != nil { + t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err) + } else if age != tt.want { + t.Errorf("%d: age=%d, want %d", n, age, tt.want) + } + } + +} + +// just a test of fakedb itself +func TestBogusPreboundParameters(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") + if err == nil { + t.Fatalf("expected error") + } + if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` { + t.Errorf("unexpected error: %v", err) + } +} + +func TestExec(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Errorf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + type execTest struct { + args []interface{} + wantErr string + } + execTests := []execTest{ + // Okay: + {[]interface{}{"Brad", 31}, ""}, + {[]interface{}{"Brad", int64(31)}, ""}, + {[]interface{}{"Bob", "32"}, ""}, + {[]interface{}{7, 9}, ""}, + + // Invalid conversions: + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting Exec argument #1's type: sql/driver: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, "sql: converting Exec argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, + + // Wrong number of args: + {[]interface{}{}, "sql: expected 2 arguments, got 0"}, + {[]interface{}{1, 2, 3}, "sql: expected 2 arguments, got 3"}, + } + for n, et := range execTests { + _, err := stmt.Exec(et.args...) + errStr := "" + if err != nil { + errStr = err.Error() + } + if errStr != et.wantErr { + t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q", + n, et.args, errStr, et.wantErr) + } + } +} + +func TestTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } +} + +// Issue: http://golang.org/issue/2784 +// This test didn't fail before because we got luckly with the fakedb driver. +// It was failing, and now not, in github.com/bradfitz/go-sql-test +func TestTxQuery(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + exec(t, db, "INSERT|t1|name=Alice") + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r, err := tx.Query("SELECT|t1|name|") + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected one row") + } + + var x string + err = r.Scan(&x) + if err != nil { + t.Fatal(err) + } +} + +func TestTxQueryInvalid(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Query("SELECT|t1|name|") + if err == nil { + t.Fatal("Error expected") + } +} + +// Tests fix for issue 2542, that we release a lock when querying on +// a closed connection. +func TestIssue2542Deadlock(t *testing.T) { + db := newTestDB(t, "people") + closeDB(t, db) + for i := 0; i < 2; i++ { + _, err := db.Query("SELECT|people|age,name|") + if err == nil { + t.Fatalf("expected error") + } + } +} + +// Tests fix for issue 2788, that we bind nil to a []byte if the +// value in the column is sql null +func TestNullByteSlice(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + exec(t, db, "INSERT|t|id=10,name=?", nil) + + var name []byte + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatal(err) + } + if name != nil { + t.Fatalf("name []byte should be nil for null column value, got: %#v", name) + } + + exec(t, db, "INSERT|t|id=11,name=?", "bob") + err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name) + if err != nil { + t.Fatal(err) + } + if string(name) != "bob" { + t.Fatalf("name []byte should be bob, got: %q", string(name)) + } +} + +func TestPointerParamsAndScans(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + + bob := "bob" + var name *string + + name = &bob + exec(t, db, "INSERT|t|id=10,name=?", name) + name = nil + exec(t, db, "INSERT|t|id=20,name=?", name) + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatalf("querying id 10: %v", err) + } + if name == nil { + t.Errorf("id 10's name = nil; want bob") + } else if *name != "bob" { + t.Errorf("id 10's name = %q; want bob", *name) + } + + err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name) + if err != nil { + t.Fatalf("querying id 20: %v", err) + } + if name != nil { + t.Errorf("id 20 = %q; want nil", *name) + } +} + +func TestQueryRowClosingStmt(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name) + if err != nil { + t.Fatal(err) + } + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 free conn") + } + fakeConn := db.freeConn[0].(*fakeConn) + if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { + t.Errorf("statement close mismatch: made %d, closed %d", made, closed) + } +} + +type nullTestRow struct { + nullParam interface{} + notNullParam interface{} + scanNullVal interface{} +} + +type nullTestSpec struct { + nullType string + notNullType string + rows [6]nullTestRow +} + +func TestNullStringParam(t *testing.T) { + spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{ + {NullString{"aqua", true}, "", NullString{"aqua", true}}, + {NullString{"brown", false}, "", NullString{"", false}}, + {"chartreuse", "", NullString{"chartreuse", true}}, + {NullString{"darkred", true}, "", NullString{"darkred", true}}, + {NullString{"eel", false}, "", NullString{"", false}}, + {"foo", NullString{"black", false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullInt64Param(t *testing.T) { + spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{ + {NullInt64{31, true}, 1, NullInt64{31, true}}, + {NullInt64{-22, false}, 1, NullInt64{0, false}}, + {22, 1, NullInt64{22, true}}, + {NullInt64{33, true}, 1, NullInt64{33, true}}, + {NullInt64{222, false}, 1, NullInt64{0, false}}, + {0, NullInt64{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullFloat64Param(t *testing.T) { + spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{ + {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}}, + {NullFloat64{13.1, false}, 1, NullFloat64{0, false}}, + {-22.9, 1, NullFloat64{-22.9, true}}, + {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}}, + {NullFloat64{222, false}, 1, NullFloat64{0, false}}, + {10, NullFloat64{31.2, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullBoolParam(t *testing.T) { + spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{ + {NullBool{false, true}, true, NullBool{false, true}}, + {NullBool{true, false}, false, NullBool{false, false}}, + {true, true, NullBool{true, true}}, + {NullBool{true, true}, false, NullBool{true, true}}, + {NullBool{true, false}, true, NullBool{false, false}}, + {true, NullBool{true, false}, nil}, + }} + nullTestRun(t, spec) +} + +func nullTestRun(t *testing.T, spec nullTestSpec) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType)) + + // Inserts with db.Exec: + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam) + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam) + + // Inserts with a prepared statement: + stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt.Close() + if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { + t.Errorf("exec insert chris: %v", err) + } + if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil { + t.Errorf("exec insert dave: %v", err) + } + if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil { + t.Errorf("exec insert eleanor: %v", err) + } + + // Can't put null val into non-null col + if _, err := stmt.Exec(6, "bob", spec.rows[5].nullParam, spec.rows[5].notNullParam); err == nil { + t.Errorf("expected error inserting nil val with prepared statement Exec") + } + + _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil) + if err == nil { + // TODO: this test fails, but it's just because + // fakeConn implements the optional Execer interface, + // so arguably this is the correct behavior. But + // maybe I should flesh out the fakeConn.Exec + // implementation so this properly fails. + // t.Errorf("expected error inserting nil name with Exec") + } + + paramtype := reflect.TypeOf(spec.rows[0].nullParam) + bindVal := reflect.New(paramtype).Interface() + + for i := 0; i < 5; i++ { + id := i + 1 + if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil { + t.Errorf("id=%d Scan: %v", id, err) + } + bindValDeref := reflect.ValueOf(bindVal).Elem().Interface() + if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) { + t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal) + } + } +} + +func stack() string { + buf := make([]byte, 1024) + return string(buf[:runtime.Stack(buf, false)]) +} |