From d11f623794a97c1d6b50814ac3d60a0959b85f68 Mon Sep 17 00:00:00 2001 From: gedi Date: Mon, 6 Feb 2017 14:38:57 +0200 Subject: [PATCH 1/9] implements named arguments support and adds delay expectation for context deadline simulation --- expectations.go | 59 ++++++++++++++++++++++--- expectations_test.go | 100 ++++++++++++++++++++++++++++++++++++++----- sqlmock.go | 51 +++++++++++++++++++--- sqlmock_go18.go | 5 +++ 4 files changed, 191 insertions(+), 24 deletions(-) create mode 100644 sqlmock_go18.go diff --git a/expectations.go b/expectations.go index 5b6865e..3902ff3 100644 --- a/expectations.go +++ b/expectations.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" "sync" + "time" ) // an expectation interface @@ -54,6 +55,7 @@ func (e *ExpectedClose) String() string { // returned by *Sqlmock.ExpectBegin. type ExpectedBegin struct { commonExpectation + delay time.Duration } // WillReturnError allows to set an error for *sql.DB.Begin action @@ -71,6 +73,13 @@ func (e *ExpectedBegin) String() string { return msg } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { + e.delay = duration + return e +} + // ExpectedCommit is used to manage *sql.Tx.Commit expectation // returned by *Sqlmock.ExpectCommit. type ExpectedCommit struct { @@ -118,7 +127,8 @@ func (e *ExpectedRollback) String() string { // Returned by *Sqlmock.ExpectQuery. type ExpectedQuery struct { queryBasedExpectation - rows driver.Rows + rows driver.Rows + delay time.Duration } // WithArgs will match given expected args to actual database query arguments. @@ -142,6 +152,13 @@ func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { return e } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { + e.delay = duration + return e +} + // String returns string representation func (e *ExpectedQuery) String() string { msg := "ExpectedQuery => expecting Query or QueryRow which:" @@ -178,6 +195,7 @@ func (e *ExpectedQuery) String() string { type ExpectedExec struct { queryBasedExpectation result driver.Result + delay time.Duration } // WithArgs will match given expected args to actual database exec operation arguments. @@ -194,6 +212,13 @@ func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { return e } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { + e.delay = duration + return e +} + // String returns string representation func (e *ExpectedExec) String() string { msg := "ExpectedExec => expecting Exec which:" @@ -244,6 +269,7 @@ type ExpectedPrepare struct { sqlRegex *regexp.Regexp statement driver.Stmt closeErr error + delay time.Duration } // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. @@ -258,6 +284,13 @@ func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { return e } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { + e.delay = duration + return e +} + // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { @@ -300,7 +333,7 @@ type queryBasedExpectation struct { args []driver.Value } -func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (err error) { +func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { if !e.queryMatches(sql) { return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) } @@ -323,7 +356,7 @@ func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } -func (e *queryBasedExpectation) argsMatches(args []driver.Value) error { +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil } @@ -334,14 +367,26 @@ func (e *queryBasedExpectation) argsMatches(args []driver.Value) error { // custom argument matcher matcher, ok := e.args[k].(Argument) if ok { - if !matcher.Match(v) { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } continue } + dval := e.args[k] + if named, isNamed := dval.(namedValue); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + if v.Ordinal != named.Ordinal { + return fmt.Errorf("named argument %d: ordinal position: \"%d\" does not match expected: \"%d\"", k, v.Ordinal, named.Ordinal) + } + } + // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(e.args[k]) + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } @@ -350,8 +395,8 @@ func (e *queryBasedExpectation) argsMatches(args []driver.Value) error { return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) } - if !reflect.DeepEqual(darg, args[k]) { - return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, args[k], args[k]) + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) } } return nil diff --git a/expectations_test.go b/expectations_test.go index 032f029..6238532 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -10,29 +10,38 @@ import ( func TestQueryExpectationArgComparison(t *testing.T) { e := &queryBasedExpectation{} - against := []driver.Value{int64(5)} + against := []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) } e.args = []driver.Value{5, "str"} - against = []driver.Value{int64(5)} + against = []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the size is not the same") } - against = []driver.Value{int64(3), "str"} + against = []namedValue{ + {Value: int64(3), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the first argument (int value) is different") } - against = []driver.Value{int64(5), "st"} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "st", Ordinal: 2}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the second argument (string value) is different") } - against = []driver.Value{int64(5), "str"} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, but it did not: %s", err) } @@ -41,7 +50,10 @@ func TestQueryExpectationArgComparison(t *testing.T) { tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") e.args = []driver.Value{5, tm} - against = []driver.Value{int64(5), tm} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: tm, Ordinal: 2}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, but it did not") } @@ -52,29 +64,95 @@ func TestQueryExpectationArgComparison(t *testing.T) { } } +func TestQueryExpectationNamedArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + against := []namedValue{{Value: int64(5), Name: "id"}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{ + namedValue{Name: "id", Value: int64(5)}, + namedValue{Name: "s", Value: "str"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "s"}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "username"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to Name") + } + + e.args = []driver.Value{ + namedValue{Ordinal: 1, Value: int64(5)}, + namedValue{Ordinal: 2, Value: "str"}, + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 0}, + {Value: "str", Ordinal: 1}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to wrong Ordinal position") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } +} + func TestQueryExpectationArgComparisonBool(t *testing.T) { var e *queryBasedExpectation e = &queryBasedExpectation{args: []driver.Value{true}} - against := []driver.Value{true} + against := []namedValue{ + {Value: true, Ordinal: 1}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, since arguments are the same") } e = &queryBasedExpectation{args: []driver.Value{false}} - against = []driver.Value{false} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, since argument are the same") } e = &queryBasedExpectation{args: []driver.Value{true}} - against = []driver.Value{false} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since argument is different") } e = &queryBasedExpectation{args: []driver.Value{false}} - against = []driver.Value{true} + against = []namedValue{ + {Value: true, Ordinal: 1}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since argument is different") } @@ -117,7 +195,7 @@ func TestBuildQuery(t *testing.T) { name = 'John' and address = 'Jakarta' - + ` mock.ExpectQuery(query) diff --git a/sqlmock.go b/sqlmock.go index 7ac8076..500b5c6 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -15,6 +15,7 @@ import ( "database/sql/driver" "fmt" "regexp" + "time" ) // Sqlmock interface serves to create expectations @@ -184,6 +185,7 @@ func (c *sqlmock) Begin() (driver.Tx, error) { expected.triggered = true expected.Unlock() + defer time.Sleep(expected.delay) return c, expected.err } @@ -194,7 +196,18 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin { } // Exec meets http://golang.org/pkg/database/sql/driver/#Execer -func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + return c.exec(nil, query, namedArgs) +} + +func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -230,17 +243,19 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er return nil, fmt.Errorf(msg, query, args) } - defer expected.Unlock() - if !expected.queryMatches(query) { + expected.Unlock() return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { + expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true + defer time.Sleep(expected.delay) + defer expected.Unlock() if expected.err != nil { return nil, expected.err // mocked to return error @@ -292,12 +307,14 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { } return nil, fmt.Errorf(msg, query) } - defer expected.Unlock() if !expected.sqlRegex.MatchString(query) { + expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } expected.triggered = true + defer time.Sleep(expected.delay) + defer expected.Unlock() return &statement{c, query, expected.closeErr}, expected.err } @@ -308,8 +325,27 @@ func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { return e } +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + // Query meets http://golang.org/pkg/database/sql/driver/#Queryer func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + return c.query(nil, query, namedArgs) +} + +// in order to prevent dependencies, we use Context as a plain interface +// since it is only related to internal implementation +func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -346,18 +382,21 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err return nil, fmt.Errorf(msg, query, args) } - defer expected.Unlock() - if !expected.queryMatches(query) { + expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { + expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true + defer time.Sleep(expected.delay) + defer expected.Unlock() + if expected.err != nil { return nil, expected.err // mocked to return error } diff --git a/sqlmock_go18.go b/sqlmock_go18.go new file mode 100644 index 0000000..2021007 --- /dev/null +++ b/sqlmock_go18.go @@ -0,0 +1,5 @@ +// +build go1.8 + +package sqlmock + +// @TODO context based extensions From 965003de80e9f030b815e69266c52ca32d2408b9 Mon Sep 17 00:00:00 2001 From: gedi Date: Tue, 7 Feb 2017 12:20:08 +0200 Subject: [PATCH 2/9] implements Context based sql driver extensions --- sqlmock.go | 88 +++++++++++++++++++-------- sqlmock_go18.go | 140 ++++++++++++++++++++++++++++++++++++++++++- sqlmock_go18_test.go | 3 + 3 files changed, 205 insertions(+), 26 deletions(-) create mode 100644 sqlmock_go18_test.go diff --git a/sqlmock.go b/sqlmock.go index 500b5c6..536fa13 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { + ex, err := c.beginExpectation() + if err != nil { + return nil, err + } + + return c.begin(ex) +} + +func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) { + defer time.Sleep(expected.delay) + return c, nil +} + +func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) { expected.triggered = true expected.Unlock() - defer time.Sleep(expected.delay) - return c, expected.err + + return expected, expected.err } func (c *sqlmock) ExpectBegin() *ExpectedBegin { @@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) Value: v, } } - return c.exec(nil, query, namedArgs) + + ex, err := c.execExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + return c.exec(ex) } -func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) { +func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr } return nil, fmt.Errorf(msg, query, args) } + defer expected.Unlock() if !expected.queryMatches(query) { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true - defer time.Sleep(expected.delay) - defer expected.Unlock() - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) } - return expected.result, err + return expected, nil +} + +func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) { + defer time.Sleep(expected.delay) + return expected.result, nil } func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { @@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { + ex, err := c.prepareExpectation(query) + if err != nil { + return nil, err + } + + return c.prepare(ex, query) +} + +func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool @@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { } return nil, fmt.Errorf(msg, query) } + defer expected.Unlock() if !expected.sqlRegex.MatchString(query) { - expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } expected.triggered = true + return expected, expected.err +} + +func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) { defer time.Sleep(expected.delay) - defer expected.Unlock() - return &statement{c, query, expected.closeErr}, expected.err + return &statement{c, query, expected.closeErr}, nil } func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { @@ -332,7 +365,7 @@ type namedValue struct { } // Query meets http://golang.org/pkg/database/sql/driver/#Queryer -func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { namedArgs := make([]namedValue, len(args)) for i, v := range args { namedArgs[i] = namedValue{ @@ -340,12 +373,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err Value: v, } } - return c.query(nil, query, namedArgs) + + ex, err := c.queryExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + return c.query(ex) } -// in order to prevent dependencies, we use Context as a plain interface -// since it is only related to internal implementation -func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) { +func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -382,21 +419,17 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr return nil, fmt.Errorf(msg, query, args) } + defer expected.Unlock() + if !expected.queryMatches(query) { - expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true - - defer time.Sleep(expected.delay) - defer expected.Unlock() - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -404,8 +437,13 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr if expected.rows == nil { return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } + return expected, nil +} + +func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) { + defer time.Sleep(expected.delay) - return expected.rows, err + return expected.rows, nil } func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 2021007..0c82a3b 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -2,4 +2,142 @@ package sqlmock -// @TODO context based extensions +import ( + "context" + "database/sql/driver" + "fmt" +) + +var CancelledStatementErr = fmt.Errorf("canceling query due to user request") + +// Implement the "QueryerContext" interface +func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.queryExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + type result struct { + rows driver.Rows + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + rows, err := c.query(ex) + exec <- result{rows, err} + }() + + select { + case res := <-exec: + return res.rows, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ExecerContext" interface +func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.execExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + type result struct { + rs driver.Result + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + rs, err := c.exec(ex) + exec <- result{rs, err} + }() + + select { + case res := <-exec: + return res.rs, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ConnBeginTx" interface +func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + ex, err := c.beginExpectation() + if err != nil { + return nil, err + } + + type result struct { + tx driver.Tx + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + tx, err := c.begin(ex) + exec <- result{tx, err} + }() + + select { + case res := <-exec: + return res.tx, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ConnPrepareContext" interface +func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + ex, err := c.prepareExpectation(query) + if err != nil { + return nil, err + } + + type result struct { + stmt driver.Stmt + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + stmt, err := c.prepare(ex, query) + exec <- result{stmt, err} + }() + + select { + case res := <-exec: + return res.stmt, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go new file mode 100644 index 0000000..713ed0f --- /dev/null +++ b/sqlmock_go18_test.go @@ -0,0 +1,3 @@ +// +build go1.8 + +package sqlmock From cfb2877c66e3438409cdc4c6fe478c36e56ff274 Mon Sep 17 00:00:00 2001 From: gedi Date: Tue, 7 Feb 2017 15:03:05 +0200 Subject: [PATCH 3/9] tests Context sql driver extensions --- arg_matcher_before_go18.go | 45 +++++ arg_matcher_go18.go | 54 ++++++ expectations.go | 47 ------ expectations_test.go | 58 ------- expectations_test_go18.go | 64 +++++++ sqlmock.go | 47 ++---- sqlmock_go18.go | 97 +++-------- sqlmock_go18_test.go | 332 +++++++++++++++++++++++++++++++++++++ 8 files changed, 529 insertions(+), 215 deletions(-) create mode 100644 arg_matcher_before_go18.go create mode 100644 arg_matcher_go18.go create mode 100644 expectations_test_go18.go diff --git a/arg_matcher_before_go18.go b/arg_matcher_before_go18.go new file mode 100644 index 0000000..52eb369 --- /dev/null +++ b/arg_matcher_before_go18.go @@ -0,0 +1,45 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/arg_matcher_go18.go b/arg_matcher_go18.go new file mode 100644 index 0000000..610eac3 --- /dev/null +++ b/arg_matcher_go18.go @@ -0,0 +1,54 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + // @TODO should we assert either all args are named or ordinal? + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + if named, isNamed := dval.(sql.NamedArg); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + } + + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/expectations.go b/expectations.go index 3902ff3..b19fbc9 100644 --- a/expectations.go +++ b/expectations.go @@ -3,7 +3,6 @@ package sqlmock import ( "database/sql/driver" "fmt" - "reflect" "regexp" "strings" "sync" @@ -355,49 +354,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } - -func (e *queryBasedExpectation) argsMatches(args []namedValue) error { - if nil == e.args { - return nil - } - if len(args) != len(e.args) { - return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) - } - for k, v := range args { - // custom argument matcher - matcher, ok := e.args[k].(Argument) - if ok { - // @TODO: does it make sense to pass value instead of named value? - if !matcher.Match(v.Value) { - return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) - } - continue - } - - dval := e.args[k] - if named, isNamed := dval.(namedValue); isNamed { - dval = named.Value - if v.Name != named.Name { - return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) - } - if v.Ordinal != named.Ordinal { - return fmt.Errorf("named argument %d: ordinal position: \"%d\" does not match expected: \"%d\"", k, v.Ordinal, named.Ordinal) - } - } - - // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(dval) - if err != nil { - return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) - } - - if !driver.IsValue(darg) { - return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) - } - - if !reflect.DeepEqual(darg, v.Value) { - return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) - } - } - return nil -} diff --git a/expectations_test.go b/expectations_test.go index 6238532..2e3c097 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -64,64 +64,6 @@ func TestQueryExpectationArgComparison(t *testing.T) { } } -func TestQueryExpectationNamedArgComparison(t *testing.T) { - e := &queryBasedExpectation{} - against := []namedValue{{Value: int64(5), Name: "id"}} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) - } - - e.args = []driver.Value{ - namedValue{Name: "id", Value: int64(5)}, - namedValue{Name: "s", Value: "str"}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since the size is not the same") - } - - against = []namedValue{ - {Value: int64(5), Name: "id"}, - {Value: "str", Name: "s"}, - } - - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should have matched, but it did not: %v", err) - } - - against = []namedValue{ - {Value: int64(5), Name: "id"}, - {Value: "str", Name: "username"}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments matched, but it should have not due to Name") - } - - e.args = []driver.Value{ - namedValue{Ordinal: 1, Value: int64(5)}, - namedValue{Ordinal: 2, Value: "str"}, - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 0}, - {Value: "str", Ordinal: 1}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments matched, but it should have not due to wrong Ordinal position") - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 1}, - {Value: "str", Ordinal: 2}, - } - - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should have matched, but it did not: %v", err) - } -} - func TestQueryExpectationArgComparisonBool(t *testing.T) { var e *queryBasedExpectation diff --git a/expectations_test_go18.go b/expectations_test_go18.go new file mode 100644 index 0000000..5f30d2f --- /dev/null +++ b/expectations_test_go18.go @@ -0,0 +1,64 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "testing" +) + +func TestQueryExpectationNamedArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + against := []namedValue{{Value: int64(5), Name: "id"}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{ + sql.Named("id", 5), + sql.Named("s", "str"), + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "s"}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "username"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to Name") + } + + e.args = []driver.Value{int64(5), "str"} + + against = []namedValue{ + {Value: int64(5), Ordinal: 0}, + {Value: "str", Ordinal: 1}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to wrong Ordinal position") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } +} diff --git a/sqlmock.go b/sqlmock.go index 536fa13..2052174 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -155,20 +155,16 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { - ex, err := c.beginExpectation() + ex, err := c.begin() if err != nil { return nil, err } - return c.begin(ex) -} - -func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) { - defer time.Sleep(expected.delay) + time.Sleep(ex.delay) return c, nil } -func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) { +func (c *sqlmock) begin() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -219,15 +215,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) } } - ex, err := c.execExpectation(query, namedArgs) + ex, err := c.exec(query, namedArgs) if err != nil { return nil, err } - return c.exec(ex) + time.Sleep(ex.delay) + return ex.result, nil } -func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) { +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -284,11 +281,6 @@ func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExe return expected, nil } -func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) { - defer time.Sleep(expected.delay) - return expected.result, nil -} - func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { e := &ExpectedExec{} sqlRegexStr = stripQuery(sqlRegexStr) @@ -299,15 +291,16 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { - ex, err := c.prepareExpectation(query) + ex, err := c.prepare(query) if err != nil { return nil, err } - return c.prepare(ex, query) + time.Sleep(ex.delay) + return &statement{c, query, ex.closeErr}, nil } -func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { +func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool @@ -346,11 +339,6 @@ func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { return expected, expected.err } -func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) { - defer time.Sleep(expected.delay) - return &statement{c, query, expected.closeErr}, nil -} - func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { sqlRegexStr = stripQuery(sqlRegexStr) e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} @@ -374,15 +362,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) } } - ex, err := c.queryExpectation(query, namedArgs) + ex, err := c.query(query, namedArgs) if err != nil { return nil, err } - return c.query(ex) + time.Sleep(ex.delay) + return ex.rows, nil } -func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) { +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -440,12 +429,6 @@ func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQu return expected, nil } -func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) { - defer time.Sleep(expected.delay) - - return expected.rows, nil -} - func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { e := &ExpectedQuery{} sqlRegexStr = stripQuery(sqlRegexStr) diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 0c82a3b..7b3c949 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -5,10 +5,11 @@ package sqlmock import ( "context" "database/sql/driver" - "fmt" + "errors" + "time" ) -var CancelledStatementErr = fmt.Errorf("canceling query due to user request") +var ErrCancelled = errors.New("canceling query due to user request") // Implement the "QueryerContext" interface func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -17,31 +18,16 @@ func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver. namedArgs[i] = namedValue(nv) } - ex, err := c.queryExpectation(query, namedArgs) + ex, err := c.query(query, namedArgs) if err != nil { return nil, err } - type result struct { - rows driver.Rows - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - rows, err := c.query(ex) - exec <- result{rows, err} - }() - select { - case res := <-exec: - return res.rows, res.err + case <-time.After(ex.delay): + return ex.rows, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } @@ -52,91 +38,46 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N namedArgs[i] = namedValue(nv) } - ex, err := c.execExpectation(query, namedArgs) + ex, err := c.exec(query, namedArgs) if err != nil { return nil, err } - type result struct { - rs driver.Result - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - rs, err := c.exec(ex) - exec <- result{rs, err} - }() - select { - case res := <-exec: - return res.rs, res.err + case <-time.After(ex.delay): + return ex.result, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } // Implement the "ConnBeginTx" interface func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - ex, err := c.beginExpectation() + ex, err := c.begin() if err != nil { return nil, err } - type result struct { - tx driver.Tx - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - tx, err := c.begin(ex) - exec <- result{tx, err} - }() - select { - case res := <-exec: - return res.tx, res.err + case <-time.After(ex.delay): + return c, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } // Implement the "ConnPrepareContext" interface func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - ex, err := c.prepareExpectation(query) + ex, err := c.prepare(query) if err != nil { return nil, err } - type result struct { - stmt driver.Stmt - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - stmt, err := c.prepare(ex, query) - exec <- result{stmt, err} - }() - select { - case res := <-exec: - return res.stmt, res.err + case <-time.After(ex.delay): + return &statement{c, query, ex.closeErr}, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index 713ed0f..e491fbd 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -1,3 +1,335 @@ // +build go1.8 package sqlmock + +import ( + "context" + "database/sql" + "testing" + "time" +) + +func TestContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExecWithNamedArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExec(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + res, err := db.ExecContext(ctx, "DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + affected, err := res.RowsAffected() + if affected != 1 { + t.Errorf("expected affected rows 1, but got %v", affected) + } + + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQuery(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id ="). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Millisecond * 3). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5)) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if !rows.Next() { + t.Error("expected one row, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBeginCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.BeginTx(ctx, nil) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.BeginTx(ctx, nil) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBegin(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if tx == nil { + t.Error("expected tx, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepareCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.PrepareContext(ctx, "SELECT") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.PrepareContext(ctx, "SELECT") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepare(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.PrepareContext(ctx, "SELECT") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if stmt == nil { + t.Error("expected stmt, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} From 6bbe187a1a2c87b3249bd159d9d2401d190aa6a5 Mon Sep 17 00:00:00 2001 From: gedi Date: Wed, 8 Feb 2017 15:09:40 +0200 Subject: [PATCH 4/9] closes #60 --- sqlmock.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sqlmock.go b/sqlmock.go index 2052174..b906a3f 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -67,6 +67,11 @@ type Sqlmock interface { // By default it is set to - true. But if you use goroutines // to parallelize your query executation, that option may // be handy. + // + // This option may be turned on anytime during tests. As soon + // as it is switched to false, expectations will be matched + // in any order. Or otherwise if switched to true, any unmatched + // expectations will be expected in order MatchExpectationsInOrder(bool) } From 42ab7c33d093896262b1bb933889189082c20bc1 Mon Sep 17 00:00:00 2001 From: gedi Date: Wed, 8 Feb 2017 15:32:39 +0200 Subject: [PATCH 5/9] implements Pinger (without expectation yet) and prepared stmt Context methods --- sqlmock_go18.go | 17 +++++++++ sqlmock_go18_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 7b3c949..c49429c 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -81,4 +81,21 @@ func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt } } +// Implement the "Pinger" interface +// for now we do not have a Ping expectation +// may be something for the future +func (c *sqlmock) Ping(ctx context.Context) error { + return nil +} + +// Implement the "StmtExecContext" interface +func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return stmt.conn.ExecContext(ctx, stmt.query, args) +} + +// Implement the "StmtQueryContext" interface +func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return stmt.conn.QueryContext(ctx, stmt.query, args) +} + // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index e491fbd..9eadcb5 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -47,6 +47,50 @@ func TestContextExecCancel(t *testing.T) { } } +func TestPreparedStatementContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("DELETE FROM users"). + ExpectExec(). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + func TestContextExecWithNamedArg(t *testing.T) { t.Parallel() db, mock, err := New() @@ -164,6 +208,53 @@ func TestContextQueryCancel(t *testing.T) { } } +func TestPreparedStatementContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + ExpectQuery(). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + func TestContextQuery(t *testing.T) { t.Parallel() db, mock, err := New() From 128bf5c539d9ca228e5ae3f2d02770c8ffb6dd91 Mon Sep 17 00:00:00 2001 From: gedi Date: Wed, 8 Feb 2017 17:35:32 +0200 Subject: [PATCH 6/9] implements next rows result set support --- expectations.go | 15 +-- ...ore_go18.go => expectations_before_go18.go | 7 ++ arg_matcher_go18.go => expectations_go18.go | 11 +++ rows.go | 93 +++++++++---------- rows_go18.go | 20 ++++ rows_go18_test.go | 92 ++++++++++++++++++ 6 files changed, 178 insertions(+), 60 deletions(-) rename arg_matcher_before_go18.go => expectations_before_go18.go (84%) rename arg_matcher_go18.go => expectations_go18.go (83%) create mode 100644 rows_go18.go create mode 100644 rows_go18_test.go diff --git a/expectations.go b/expectations.go index b19fbc9..adc726e 100644 --- a/expectations.go +++ b/expectations.go @@ -144,13 +144,6 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { return e } -// WillReturnRows specifies the set of resulting rows that will be returned -// by the triggered query -func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { - e.rows = rows - return e -} - // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { @@ -175,9 +168,11 @@ func (e *ExpectedQuery) String() string { if e.rows != nil { msg += "\n - should return rows:\n" - rs, _ := e.rows.(*rows) - for i, row := range rs.rows { - msg += fmt.Sprintf(" %d - %+v\n", i, row) + rs, _ := e.rows.(*rowSets) + for _, set := range rs.sets { + for i, row := range set.rows { + msg += fmt.Sprintf(" %d - %+v\n", i, row) + } } msg = strings.TrimSpace(msg) } diff --git a/arg_matcher_before_go18.go b/expectations_before_go18.go similarity index 84% rename from arg_matcher_before_go18.go rename to expectations_before_go18.go index 52eb369..146f240 100644 --- a/arg_matcher_before_go18.go +++ b/expectations_before_go18.go @@ -8,6 +8,13 @@ import ( "reflect" ) +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { + e.rows = &rowSets{sets: []*Rows{rows}} + return e +} + func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil diff --git a/arg_matcher_go18.go b/expectations_go18.go similarity index 83% rename from arg_matcher_go18.go rename to expectations_go18.go index 610eac3..29eeb30 100644 --- a/arg_matcher_go18.go +++ b/expectations_go18.go @@ -9,6 +9,17 @@ import ( "reflect" ) +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + sets := make([]*Rows, len(rows)) + for i, r := range rows { + sets[i] = r + } + e.rows = &rowSets{sets: sets} + return e +} + func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil diff --git a/rows.go b/rows.go index 8b6beb6..43681d4 100644 --- a/rows.go +++ b/rows.go @@ -18,57 +18,22 @@ var CSVColumnParser = func(s string) []byte { return []byte(s) } -// Rows interface allows to construct rows -// which also satisfies database/sql/driver.Rows interface -type Rows interface { - // composed interface, supports sql driver.Rows - driver.Rows - - // AddRow composed from database driver.Value slice - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - AddRow(columns ...driver.Value) Rows - - // FromCSVString build rows from csv string. - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - FromCSVString(s string) Rows - - // RowError allows to set an error - // which will be returned when a given - // row number is read - RowError(row int, err error) Rows - - // CloseError allows to set an error - // which will be returned by rows.Close - // function. - // - // The close error will be triggered only in cases - // when rows.Next() EOF was not yet reached, that is - // a default sql library behavior - CloseError(err error) Rows +type rowSets struct { + sets []*Rows + pos int } -type rows struct { - cols []string - rows [][]driver.Value - pos int - nextErr map[int]error - closeErr error -} - -func (r *rows) Columns() []string { - return r.cols +func (rs *rowSets) Columns() []string { + return rs.sets[rs.pos].cols } -func (r *rows) Close() error { - return r.closeErr +func (rs *rowSets) Close() error { + return rs.sets[rs.pos].closeErr } // advances to next row -func (r *rows) Next(dest []driver.Value) error { +func (rs *rowSets) Next(dest []driver.Value) error { + r := rs.sets[rs.pos] r.pos++ if r.pos > len(r.rows) { return io.EOF // per interface spec @@ -81,24 +46,48 @@ func (r *rows) Next(dest []driver.Value) error { return r.nextErr[r.pos-1] } +// Rows is a mocked collection of rows to +// return for Query result +type Rows struct { + cols []string + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error +} + // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows -func NewRows(columns []string) Rows { - return &rows{cols: columns, nextErr: make(map[int]error)} +func NewRows(columns []string) *Rows { + return &Rows{cols: columns, nextErr: make(map[int]error)} } -func (r *rows) CloseError(err error) Rows { +// CloseError allows to set an error +// which will be returned by rows.Close +// function. +// +// The close error will be triggered only in cases +// when rows.Next() EOF was not yet reached, that is +// a default sql library behavior +func (r *Rows) CloseError(err error) *Rows { r.closeErr = err return r } -func (r *rows) RowError(row int, err error) Rows { +// RowError allows to set an error +// which will be returned when a given +// row number is read +func (r *Rows) RowError(row int, err error) *Rows { r.nextErr[row] = err return r } -func (r *rows) AddRow(values ...driver.Value) Rows { +// AddRow composed from database driver.Value slice +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) AddRow(values ...driver.Value) *Rows { if len(values) != len(r.cols) { panic("Expected number of values to match number of columns") } @@ -112,7 +101,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows { return r } -func (r *rows) FromCSVString(s string) Rows { +// FromCSVString build rows from csv string. +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) FromCSVString(s string) *Rows { res := strings.NewReader(strings.TrimSpace(s)) csvReader := csv.NewReader(res) diff --git a/rows_go18.go b/rows_go18.go new file mode 100644 index 0000000..4ecf84e --- /dev/null +++ b/rows_go18.go @@ -0,0 +1,20 @@ +// +build go1.8 + +package sqlmock + +import "io" + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) HasNextResultSet() bool { + return rs.pos+1 < len(rs.sets) +} + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) NextResultSet() error { + if !rs.HasNextResultSet() { + return io.EOF + } + + rs.pos++ + return nil +} diff --git a/rows_go18_test.go b/rows_go18_test.go new file mode 100644 index 0000000..297e7c0 --- /dev/null +++ b/rows_go18_test.go @@ -0,0 +1,92 @@ +// +build go1.8 + +package sqlmock + +import ( + "fmt" + "testing" +) + +func TestQueryMultiRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error")) + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users"). + WithArgs(5). + WillReturnRows(rs1, rs2) + + rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Error("expected a row to be available in first result set") + } + + var id int + var name string + + err = rows.Scan(&id, &name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if id != 5 || name != "hello world" { + t.Errorf("unexpected row values id: %v name: %v", id, name) + } + + if rows.Next() { + t.Error("was not expecting next row in first result set") + } + + if !rows.NextResultSet() { + t.Error("had to have next result set") + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "gopher" { + t.Errorf("unexpected row name: %v", name) + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "john" { + t.Errorf("unexpected row name: %v", name) + } + + if rows.Next() { + t.Error("expected next row to produce error") + } + + if rows.Err() == nil { + t.Error("expected an error, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} From 53b2cd1534e9f8f8830498debd1e49b63a08ffef Mon Sep 17 00:00:00 2001 From: gedi Date: Thu, 9 Feb 2017 09:26:25 +0200 Subject: [PATCH 7/9] updates readme and license --- LICENSE | 2 +- README.md | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/LICENSE b/LICENSE index 25255bb..7f8bedf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) -Copyright (c) 2013-2016, DATA-DOG team +Copyright (c) 2013-2017, DATA-DOG team All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index 80b5bc0..f3d5487 100644 --- a/README.md +++ b/README.md @@ -10,19 +10,16 @@ maintain correct **TDD** workflow. - this library is now complete and stable. (you may not find new changes for this reason) - supports concurrency and multiple connections. +- supports **go1.8** Context related feature mocking and Named sql parameters. - does not require any modifications to your source code. - the driver allows to mock any sql driver method behavior. - has strict by default expectation order matching. -- has no vendor dependencies. +- has no third party dependencies. ## Install go get gopkg.in/DATA-DOG/go-sqlmock.v1 -If you need an old version, checkout **go-sqlmock** at gopkg.in: - - go get gopkg.in/DATA-DOG/go-sqlmock.v0 - ## Documentation and Examples Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. @@ -187,8 +184,11 @@ It only asserts that argument is of `time.Time` type. go test -race -## Changes +## Change Log +- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct + but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now + accept multiple row sets. - **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL query. It should still be validated even if Exec or Query is not executed on that prepared statement. From a00b6aa80e8d8204aa8aeb9e3b61d94f98973f91 Mon Sep 17 00:00:00 2001 From: gedi Date: Thu, 16 Feb 2017 22:33:12 +0200 Subject: [PATCH 8/9] asserts ordinal argument position, fixes expected query error message --- expectations.go | 9 +-------- expectations_go18.go | 3 ++- ..._test_go18.go => expectations_go18_test.go | 0 rows.go | 19 +++++++++++++++++++ 4 files changed, 22 insertions(+), 9 deletions(-) rename expectations_test_go18.go => expectations_go18_test.go (100%) diff --git a/expectations.go b/expectations.go index adc726e..415759e 100644 --- a/expectations.go +++ b/expectations.go @@ -167,14 +167,7 @@ func (e *ExpectedQuery) String() string { } if e.rows != nil { - msg += "\n - should return rows:\n" - rs, _ := e.rows.(*rowSets) - for _, set := range rs.sets { - for i, row := range set.rows { - msg += fmt.Sprintf(" %d - %+v\n", i, row) - } - } - msg = strings.TrimSpace(msg) + msg += fmt.Sprintf("\n - %s", e.rows) } if e.err != nil { diff --git a/expectations_go18.go b/expectations_go18.go index 29eeb30..2b4b44e 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -32,7 +32,6 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { // custom argument matcher matcher, ok := e.args[k].(Argument) if ok { - // @TODO: does it make sense to pass value instead of named value? if !matcher.Match(v.Value) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } @@ -45,6 +44,8 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if v.Name != named.Name { return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) } + } else if k+1 != v.Ordinal { + return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal) } // convert to driver converter diff --git a/expectations_test_go18.go b/expectations_go18_test.go similarity index 100% rename from expectations_test_go18.go rename to expectations_go18_test.go diff --git a/rows.go b/rows.go index 43681d4..39f9f83 100644 --- a/rows.go +++ b/rows.go @@ -3,6 +3,7 @@ package sqlmock import ( "database/sql/driver" "encoding/csv" + "fmt" "io" "strings" ) @@ -46,6 +47,24 @@ func (rs *rowSets) Next(dest []driver.Value) error { return r.nextErr[r.pos-1] } +// transforms to debuggable printable string +func (rs *rowSets) String() string { + msg := "should return rows:\n" + if len(rs.sets) == 1 { + for n, row := range rs.sets[0].rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + return strings.TrimSpace(msg) + } + for i, set := range rs.sets { + msg += fmt.Sprintf(" result set: %d\n", i) + for n, row := range set.rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + } + return strings.TrimSpace(msg) +} + // Rows is a mocked collection of rows to // return for Query result type Rows struct { From 372a183d52c500c7299f746e62df4237416f96e9 Mon Sep 17 00:00:00 2001 From: gedi Date: Tue, 21 Feb 2017 17:53:38 +0200 Subject: [PATCH 9/9] adds go1.8 to travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 31f9c88..211edd7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ go: - 1.5 - 1.6 - 1.7 + - 1.8 - tip script: go test -race -coverprofile=coverage.txt -covermode=atomic