diff --git a/README.md b/README.md index 810be02..d4151f4 100644 --- a/README.md +++ b/README.md @@ -222,6 +222,7 @@ It only asserts that argument is of `time.Time` type. ## Change Log +- **2019-04-06** - added functionality to mock a sql MetaData request - **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`. - **2018-12-11** - added expectation of Rows to be closed, while mocking expected query. - **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching. diff --git a/column.go b/column.go new file mode 100644 index 0000000..e418d2e --- /dev/null +++ b/column.go @@ -0,0 +1,77 @@ +package sqlmock + +import "reflect" + +// Column is a mocked column Metadata for rows.ColumnTypes() +type Column struct { + name string + dbType string + nullable bool + nullableOk bool + length int64 + lengthOk bool + precision int64 + scale int64 + psOk bool + scanType reflect.Type +} + +func (c *Column) Name() string { + return c.name +} + +func (c *Column) DbType() string { + return c.dbType +} + +func (c *Column) IsNullable() (bool, bool) { + return c.nullable, c.nullableOk +} + +func (c *Column) Length() (int64, bool) { + return c.length, c.lengthOk +} + +func (c *Column) PrecisionScale() (int64, int64, bool) { + return c.precision, c.scale, c.psOk +} + +func (c *Column) ScanType() reflect.Type { + return c.scanType +} + +// NewColumn returns a Column with specified name +func NewColumn(name string) *Column { + return &Column{ + name: name, + } +} + +// Nullable returns the column with nullable metadata set +func (c *Column) Nullable(nullable bool) *Column { + c.nullable = nullable + c.nullableOk = true + return c +} + +// OfType returns the column with type metadata set +func (c *Column) OfType(dbType string, sampleValue interface{}) *Column { + c.dbType = dbType + c.scanType = reflect.TypeOf(sampleValue) + return c +} + +// WithLength returns the column with length metadata set. +func (c *Column) WithLength(length int64) *Column { + c.length = length + c.lengthOk = true + return c +} + +// WithPrecisionAndScale returns the column with precision and scale metadata set. +func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column { + c.precision = precision + c.scale = scale + c.psOk = true + return c +} diff --git a/column_test.go b/column_test.go new file mode 100644 index 0000000..0311216 --- /dev/null +++ b/column_test.go @@ -0,0 +1,63 @@ +package sqlmock + +import ( + "reflect" + "testing" + "time" +) + +func TestColumn(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z") + column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := NewColumn("when").OfType("TIMESTAMP", now) + + if column1.ScanType().Kind() != reflect.String { + t.Errorf("string scanType mismatch: %v", column1.ScanType()) + } + if column2.ScanType().Kind() != reflect.Float64 { + t.Errorf("float scanType mismatch: %v", column2.ScanType()) + } + if column3.ScanType() != reflect.TypeOf(time.Time{}) { + t.Errorf("time scanType mismatch: %v", column3.ScanType()) + } + + nullable, ok := column1.IsNullable() + if !nullable || !ok { + t.Errorf("'test' column should be nullable") + } + nullable, ok = column2.IsNullable() + if nullable || !ok { + t.Errorf("'number' column should not be nullable") + } + nullable, ok = column3.IsNullable() + if ok { + t.Errorf("'when' column nullability should be unknown") + } + + length, ok := column1.Length() + if length != 100 || !ok { + t.Errorf("'test' column wrong length") + } + length, ok = column2.Length() + if ok { + t.Errorf("'number' column is not of variable length type") + } + length, ok = column3.Length() + if ok { + t.Errorf("'when' column is not of variable length type") + } + + _, _, ok = column1.PrecisionScale() + if ok { + t.Errorf("'test' column not applicable") + } + precision, scale, ok := column2.PrecisionScale() + if precision != 10 || scale != 4 || !ok { + t.Errorf("'number' column not applicable") + } + _, _, ok = column3.PrecisionScale() + if ok { + t.Errorf("'when' column not applicable") + } +} diff --git a/expectations_go18.go b/expectations_go18.go index 172bb6c..6b85ce1 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -12,11 +12,19 @@ import ( // WillReturnRows specifies the set of resulting rows that will be returned // by the triggered query func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + defs := 0 sets := make([]*Rows, len(rows)) for i, r := range rows { sets[i] = r + if r.def != nil { + defs++ + } + } + if defs > 0 && defs == len(sets) { + e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}} + } else { + e.rows = &rowSets{sets: sets, ex: e} } - e.rows = &rowSets{sets: sets, ex: e} return e } diff --git a/rows.go b/rows.go index 5f11c78..ccc5f0c 100644 --- a/rows.go +++ b/rows.go @@ -120,6 +120,7 @@ func (rs *rowSets) invalidateRaw() { type Rows struct { converter driver.ValueConverter cols []string + def []*Column rows [][]driver.Value pos int nextErr map[int]error diff --git a/rows_go18.go b/rows_go18.go index 4ecf84e..6c71eb9 100644 --- a/rows_go18.go +++ b/rows_go18.go @@ -2,7 +2,11 @@ package sqlmock -import "io" +import ( + "database/sql/driver" + "io" + "reflect" +) // Implement the "RowsNextResultSet" interface func (rs *rowSets) HasNextResultSet() bool { @@ -18,3 +22,53 @@ func (rs *rowSets) NextResultSet() error { rs.pos++ return nil } + +// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition +type rowSetsWithDefinition struct { + *rowSets +} + +// Implement the "RowsColumnTypeDatabaseTypeName" interface +func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string { + return rs.getDefinition(index).DbType() +} + +// Implement the "RowsColumnTypeLength" interface +func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.getDefinition(index).Length() +} + +// Implement the "RowsColumnTypeNullable" interface +func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) { + return rs.getDefinition(index).IsNullable() +} + +// Implement the "RowsColumnTypePrecisionScale" interface +func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.getDefinition(index).PrecisionScale() +} + +// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType +func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type { + return rs.getDefinition(index).ScanType() +} + +// return column definition from current set metadata +func (rs *rowSetsWithDefinition) getDefinition(index int) *Column { + return rs.sets[rs.pos].def[index] +} + +// NewRowsWithColumnDefinition return rows with columns metadata +func NewRowsWithColumnDefinition(columns ...*Column) *Rows { + cols := make([]string, len(columns)) + for i, column := range columns { + cols[i] = column.Name() + } + + return &Rows{ + cols: cols, + def: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } +} diff --git a/rows_go18_test.go b/rows_go18_test.go index b29a2c5..0af6d66 100644 --- a/rows_go18_test.go +++ b/rows_go18_test.go @@ -6,7 +6,9 @@ import ( "database/sql" "encoding/json" "fmt" + "reflect" "testing" + "time" ) func TestQueryMultiRows(t *testing.T) { @@ -203,3 +205,183 @@ func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *tes } queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) } + +func TestNewColumnWithDefinition(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z") + + t.Run("with one ResultSet", func(t *testing.T) { + db, mock, _ := New() + column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := mock.NewColumn("when").OfType("TIMESTAMP", now) + rows := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows.AddRow("foo.bar", float64(10.123), now) + + mQuery := mock.ExpectQuery("SELECT test, number, when from dummy") + isQuery := mQuery.WillReturnRows(rows) + isQueryClosed := mQuery.RowsWillBeClosed() + isDbClosed := mock.ExpectClose() + + query, _ := db.Query("SELECT test, number, when from dummy") + + if false == isQuery.fulfilled() { + t.Error("Query is not executed") + } + + if query.Next() { + var test string + var number float64 + var when time.Time + + if queryError := query.Scan(&test, &number, &when); queryError != nil { + t.Error(queryError) + } else if test != "foo.bar" { + t.Error("field test is not 'foo.bar'") + } else if number != float64(10.123) { + t.Error("field number is not '10.123'") + } else if when != now { + t.Errorf("field when is not %v", now) + } + + if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil { + t.Error(colTypErr) + } else if len(columnTypes) != 3 { + t.Error("number of columnTypes") + } else if name := columnTypes[0].Name(); name != "test" { + t.Errorf("field 'test' has a wrong name '%s'", name) + } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" { + t.Errorf("field 'test' has a wrong db type '%s'", dbType) + } else if columnTypes[0].ScanType().Kind() != reflect.String { + t.Error("field 'test' has a wrong scanType") + } else if _, _, ok := columnTypes[0].DecimalSize(); ok { + t.Error("field 'test' should have not precision, scale") + } else if length, ok := columnTypes[0].Length(); length != 100 || !ok { + t.Errorf("field 'test' has a wrong length '%d'", length) + } else if name := columnTypes[1].Name(); name != "number" { + t.Errorf("field 'number' has a wrong name '%s'", name) + } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" { + t.Errorf("field 'number' has a wrong db type '%s'", dbType) + } else if columnTypes[1].ScanType().Kind() != reflect.Float64 { + t.Error("field 'number' has a wrong scanType") + } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok { + t.Error("field 'number' has a wrong precision, scale") + } else if _, ok := columnTypes[1].Length(); ok { + t.Error("field 'number' is not variable length type") + } else if _, ok := columnTypes[2].Nullable(); ok { + t.Error("field 'when' should have nullability unknown") + } + } else { + t.Error("no result set") + } + + query.Close() + if false == isQueryClosed.fulfilled() { + t.Error("Query is not executed") + } + + db.Close() + if false == isDbClosed.fulfilled() { + t.Error("Db is not closed") + } + }) + + t.Run("with more then one ResultSet", func(t *testing.T) { + db, mock, _ := New() + column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := mock.NewColumn("when").OfType("TIMESTAMP", now) + rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows1.AddRow("foo.bar", float64(10.123), now) + rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows2.AddRow("bar.foo", float64(123.10), now.Add(time.Second*10)) + rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows3.AddRow("lollipop", float64(10.321), now.Add(time.Second*20)) + + mQuery := mock.ExpectQuery("SELECT test, number, when from dummy") + isQuery := mQuery.WillReturnRows(rows1, rows2, rows3) + isQueryClosed := mQuery.RowsWillBeClosed() + isDbClosed := mock.ExpectClose() + + query, _ := db.Query("SELECT test, number, when from dummy") + + if false == isQuery.fulfilled() { + t.Error("Query is not executed") + } + + rowsSi := 0 + + for query.Next() { + var test string + var number float64 + var when time.Time + + if queryError := query.Scan(&test, &number, &when); queryError != nil { + t.Error(queryError) + + } else if rowsSi == 0 && test != "foo.bar" { + t.Error("field test is not 'foo.bar'") + } else if rowsSi == 0 && number != float64(10.123) { + t.Error("field number is not '10.123'") + } else if rowsSi == 0 && when != now { + t.Errorf("field when is not %v", now) + + } else if rowsSi == 1 && test != "bar.foo" { + t.Error("field test is not 'bar.bar'") + } else if rowsSi == 1 && number != float64(123.10) { + t.Error("field number is not '123.10'") + } else if rowsSi == 1 && when != now.Add(time.Second*10) { + t.Errorf("field when is not %v", now) + + } else if rowsSi == 2 && test != "lollipop" { + t.Error("field test is not 'lollipop'") + } else if rowsSi == 2 && number != float64(10.321) { + t.Error("field number is not '10.321'") + } else if rowsSi == 2 && when != now.Add(time.Second*20) { + t.Errorf("field when is not %v", now) + } + + rowsSi++ + + if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil { + t.Error(colTypErr) + } else if len(columnTypes) != 3 { + t.Error("number of columnTypes") + } else if name := columnTypes[0].Name(); name != "test" { + t.Errorf("field 'test' has a wrong name '%s'", name) + } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" { + t.Errorf("field 'test' has a wrong db type '%s'", dbType) + } else if columnTypes[0].ScanType().Kind() != reflect.String { + t.Error("field 'test' has a wrong scanType") + } else if _, _, ok := columnTypes[0].DecimalSize(); ok { + t.Error("field 'test' should not have precision, scale") + } else if length, ok := columnTypes[0].Length(); length != 100 || !ok { + t.Errorf("field 'test' has a wrong length '%d'", length) + } else if name := columnTypes[1].Name(); name != "number" { + t.Errorf("field 'number' has a wrong name '%s'", name) + } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" { + t.Errorf("field 'number' has a wrong db type '%s'", dbType) + } else if columnTypes[1].ScanType().Kind() != reflect.Float64 { + t.Error("field 'number' has a wrong scanType") + } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok { + t.Error("field 'number' has a wrong precision, scale") + } else if _, ok := columnTypes[1].Length(); ok { + t.Error("field 'number' is not variable length type") + } else if _, ok := columnTypes[2].Nullable(); ok { + t.Error("field 'when' should have nullability unknown") + } + } + if rowsSi == 0 { + t.Error("no result set") + } + + query.Close() + if false == isQueryClosed.fulfilled() { + t.Error("Query is not executed") + } + + db.Close() + if false == isDbClosed.fulfilled() { + t.Error("Db is not closed") + } + }) +} diff --git a/sqlmock.go b/sqlmock.go index 90f789b..d074266 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -20,7 +20,7 @@ import ( // Sqlmock interface serves to create expectations // for any kind of database action in order to mock // and test real database behavior. -type Sqlmock interface { +type SqlmockCommon interface { // ExpectClose queues an expectation for this database // action to be triggered. the *ExpectedClose allows // to mock database response diff --git a/sqlmock_before_go18.go b/sqlmock_before_go18.go index 1a5b63a..9965e78 100644 --- a/sqlmock_before_go18.go +++ b/sqlmock_before_go18.go @@ -9,6 +9,12 @@ import ( "time" ) +// Sqlmock interface for Go up to 1.7 +type Sqlmock interface { + // Embed common methods + SqlmockCommon +} + type namedValue struct { Name string Ordinal int diff --git a/sqlmock_go18.go b/sqlmock_go18.go index dc37b18..f268900 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -11,6 +11,19 @@ import ( "time" ) +// Sqlmock interface for Go 1.8+ +type Sqlmock interface { + // Embed common methods + SqlmockCommon + + // NewRowsWithColumnDefinition allows Rows to be created from a + // sql driver.Value slice with a definition of sql metadata + NewRowsWithColumnDefinition(columns ...*Column) *Rows + + // New Column allows to create a Column + NewColumn(name string) *Column +} + // ErrCancelled defines an error value, which can be expected in case of // such cancellation error. var ErrCancelled = errors.New("canceling query due to user request") @@ -327,3 +340,17 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e } // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) + +// NewRowsWithColumnDefinition allows Rows to be created from a +// sql driver.Value slice with a definition of sql metadata +func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows { + r := NewRowsWithColumnDefinition(columns...) + r.converter = c.converter + return r +} + +// NewColumn allows to create a Column that can be enhanced with metadata +// using OfType/Nullable/WithLength/WithPrecisionAndScale methods. +func (c *sqlmock) NewColumn(name string) *Column { + return NewColumn(name) +}