diff --git a/config.go b/config.go index f0e34e82..c85a476a 100644 --- a/config.go +++ b/config.go @@ -28,6 +28,11 @@ import ( // Settings defines methods to get or set configuration values. type Settings interface { + // SetRetryQueryOnError enables or disable query retry-on-error features. + SetRetryQueryOnError(bool) + // Returns true if query retry is enabled. + RetryQueryOnError() bool + // SetLogging enables or disables logging. SetLogging(bool) // LoggingEnabled returns true if logging is enabled, false otherwise. @@ -45,6 +50,8 @@ type conf struct { queryLogger Logger queryLoggerMu sync.RWMutex defaultLogger defaultLogger + + queryRetryOnError uint32 } func (c *conf) Logger() Logger { @@ -58,6 +65,14 @@ func (c *conf) Logger() Logger { return c.queryLogger } +func (c *conf) SetRetryQueryOnError(v bool) { + c.setBinaryOption(&c.queryRetryOnError, v) +} + +func (c *conf) RetryQueryOnError() bool { + return c.binaryOption(&c.queryRetryOnError) +} + func (c *conf) SetLogger(lg Logger) { c.queryLoggerMu.Lock() defer c.queryLoggerMu.Unlock() @@ -65,20 +80,28 @@ func (c *conf) SetLogger(lg Logger) { c.queryLogger = lg } -func (c *conf) SetLogging(value bool) { +func (c *conf) binaryOption(dest *uint32) bool { + if v := atomic.LoadUint32(dest); v == 1 { + return true + } + return false +} + +func (c *conf) setBinaryOption(dest *uint32, value bool) { if value { - atomic.StoreUint32(&c.loggingEnabled, 1) + atomic.StoreUint32(dest, 1) return } - atomic.StoreUint32(&c.loggingEnabled, 0) + atomic.StoreUint32(dest, 0) +} + +func (c *conf) SetLogging(value bool) { + c.setBinaryOption(&c.loggingEnabled, value) } func (c *conf) LoggingEnabled() bool { - if v := atomic.LoadUint32(&c.loggingEnabled); v == 1 { - return true - } - return false + return c.binaryOption(&c.loggingEnabled) } -// Conf provides global configuration settings for upper-db. +// Conf provides default global configuration settings for upper-db. var Conf Settings = &conf{} diff --git a/errors.go b/errors.go index e6b16011..4314c91b 100644 --- a/errors.go +++ b/errors.go @@ -27,26 +27,28 @@ import ( // Error messages. var ( - ErrNoMoreRows = errors.New(`upper: no more rows in this result set`) - ErrNotConnected = errors.New(`upper: you're currently not connected`) - ErrMissingDatabaseName = errors.New(`upper: missing database name`) - ErrMissingCollectionName = errors.New(`upper: missing collection name`) - ErrCollectionDoesNotExist = errors.New(`upper: collection does not exist`) - ErrSockerOrHost = errors.New(`upper: you may connect either to a unix socket or a tcp address, but not both`) - ErrQueryLimitParam = errors.New(`upper: a query can accept only one limit parameter`) - ErrQuerySortParam = errors.New(`upper: a query can accept only one order by parameter`) - ErrQueryOffsetParam = errors.New(`upper: a query can accept only one offset parameter`) - ErrMissingConditions = errors.New(`upper: missing selector conditions`) - ErrUnsupported = errors.New(`upper: this action is currently unsupported on this database`) - ErrUndefined = errors.New(`upper: this value is undefined`) - ErrQueryIsPending = errors.New(`upper: can't execute this instruction while the result set is still open`) - ErrUnsupportedDestination = errors.New(`upper: unsupported destination type`) - ErrUnsupportedType = errors.New(`upper: this type does not support marshaling`) - ErrUnsupportedValue = errors.New(`upper: this value does not support unmarshaling`) - ErrUnknownConditionType = errors.New(`upper: arguments of type %T can't be used as constraints`) - ErrTooManyClients = errors.New(`upper: can't connect to database server: too many clients`) - ErrGivingUpTryingToConnect = errors.New(`upper: giving up trying to connect: too many clients`) - ErrMissingConnURL = errors.New(`upper: missing DSN`) - ErrNotImplemented = errors.New(`upper: call not implemented`) - ErrAlreadyWithinTransaction = errors.New(`upper: already within a transaction`) + ErrNoMoreRows = errors.New(`upper: no more rows in this result set`) + ErrNotConnected = errors.New(`upper: you're currently not connected`) + ErrMissingDatabaseName = errors.New(`upper: missing database name`) + ErrMissingCollectionName = errors.New(`upper: missing collection name`) + ErrCollectionDoesNotExist = errors.New(`upper: collection does not exist`) + ErrSockerOrHost = errors.New(`upper: you may connect either to a unix socket or a tcp address, but not both`) + ErrQueryLimitParam = errors.New(`upper: a query can accept only one limit parameter`) + ErrQuerySortParam = errors.New(`upper: a query can accept only one order by parameter`) + ErrQueryOffsetParam = errors.New(`upper: a query can accept only one offset parameter`) + ErrMissingConditions = errors.New(`upper: missing selector conditions`) + ErrUnsupported = errors.New(`upper: this action is currently unsupported on this database`) + ErrUndefined = errors.New(`upper: this value is undefined`) + ErrQueryIsPending = errors.New(`upper: can't execute this instruction while the result set is still open`) + ErrUnsupportedDestination = errors.New(`upper: unsupported destination type`) + ErrUnsupportedType = errors.New(`upper: this type does not support marshaling`) + ErrUnsupportedValue = errors.New(`upper: this value does not support unmarshaling`) + ErrUnknownConditionType = errors.New(`upper: arguments of type %T can't be used as constraints`) + ErrTooManyClients = errors.New(`upper: can't connect to database server: too many clients`) + ErrGivingUpTryingToConnect = errors.New(`upper: giving up trying to connect: too many clients`) + ErrTooManyReconnectionAttempts = errors.New(`upper: too many reconnection attempts`) + ErrMissingConnURL = errors.New(`upper: missing DSN`) + ErrNotImplemented = errors.New(`upper: call not implemented`) + ErrAlreadyWithinTransaction = errors.New(`upper: already within a transaction`) + ErrServerRefusedConnection = errors.New(`upper: database server refused connection`) ) diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 7c1fa4d0..a7136e6f 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -2,6 +2,9 @@ package sqladapter import ( "database/sql" + "database/sql/driver" + "errors" + "io" "math" "strconv" "sync" @@ -14,13 +17,53 @@ import ( "upper.io/db.v2/lib/sqlbuilder" ) +// A list of errors that mean the server is not working and that we should +// try to connect and retry the query. +var recoverableErrors = []error{ + io.EOF, + driver.ErrBadConn, + db.ErrNotConnected, + db.ErrTooManyClients, + db.ErrServerRefusedConnection, +} + +const ( + // If a query fails with a recoverable error the connection is going to be + // re-estalished and the query can be retried, each retry adds a max wait + // time of maxConnectionRetryTime + maxQueryRetryAttempts = 6 + + // Minimum interval when waiting before trying to reconnect. + minConnectionRetryInterval = time.Millisecond * 100 + + // Maximum interval when waiting before trying to reconnect. + maxConnectionRetryInterval = time.Millisecond * 1500 + + // Maximum time each connection retry attempt can take. + maxConnectionRetryTime = time.Second * 5 + + // Maximum reconnection attempts per session before giving up. + maxReconnectionAttempts uint64 = 4 + + // If this session failed to recover more than + // flushConnectionPoolAfterRecoverAttempts times, assume the entire pool + // is borked and force a clean reconnection. + flushConnectionPoolAfterRecoverAttempts = uint64(maxQueryRetryAttempts / 2) +) + var ( - lastSessID uint64 - lastTxID uint64 + errNothingToRecoverFrom = errors.New("Nothing to recover from") + errUnableToRecover = errors.New("Unable to recover from this error") ) -// HasCleanUp is implemented by structs that have a clean up routine that needs -// to be called before Close(). +var ( + lastSessID uint64 + lastTxID uint64 + lastOperationID uint64 +) + +// HasCleanUp is implemented by structs that have a clean up routine that +// needs to be called before Close(). type HasCleanUp interface { CleanUp() error } @@ -66,7 +109,7 @@ type BaseDatabase interface { Collection(string) db.Collection Driver() interface{} - WaitForConnection(func() error) error + WaitForConnection(func() error, bool) error BindSession(*sql.DB) error Session() *sql.DB @@ -78,29 +121,42 @@ type BaseDatabase interface { // NewBaseDatabase provides a BaseDatabase given a PartialDatabase func NewBaseDatabase(p PartialDatabase) BaseDatabase { d := &database{ - PartialDatabase: p, + PartialDatabase: p, + cachedCollections: cache.NewCache(), cachedStatements: cache.NewCache(), + connFn: defaultConnFn, } return d } +var defaultConnFn = func() error { + return errors.New("No connection function was defined.") +} + // database is the actual implementation of Database and joins methods from // BaseDatabase and PartialDatabase type database struct { PartialDatabase baseTx BaseTx + connectMu sync.Mutex collectionMu sync.Mutex - databaseMu sync.Mutex - name string - sess *sql.DB - sessMu sync.Mutex + connFn func() error + + name string + + sess *sql.DB + sessErr error + sessMu sync.Mutex sessID uint64 txID uint64 + connectAttempts uint64 + recoverAttempts uint64 + cachedStatements *cache.Cache cachedCollections *cache.Cache @@ -111,8 +167,59 @@ var ( _ = db.Database(&database{}) ) +func (d *database) connect(connFn func() error) error { + if connFn == nil { + return errors.New("Missing connect function") + } + + d.connectMu.Lock() + defer d.connectMu.Unlock() + + // Attempt to (re)connect + if atomic.AddUint64(&d.connectAttempts, 1) >= maxReconnectionAttempts { + return db.ErrTooManyReconnectionAttempts + } + + waitTime := minConnectionRetryInterval + + for start, i := time.Now(), 1; time.Now().Sub(start) < maxConnectionRetryTime; i++ { + waitTime = time.Duration(i) * minConnectionRetryInterval + if waitTime > maxConnectionRetryInterval { + waitTime = maxConnectionRetryInterval + } + // Wait a bit until retrying. + if waitTime > time.Duration(0) { + time.Sleep(waitTime) + } + + err := connFn() + if err == nil { + atomic.StoreUint64(&d.connectAttempts, 0) + atomic.StoreUint64(&d.recoverAttempts, 0) + return nil + } + + if !d.isRecoverableError(err) { + return err + } + } + + return db.ErrGivingUpTryingToConnect +} + // Session returns the underlying *sql.DB func (d *database) Session() *sql.DB { + if atomic.LoadUint64(&d.sessID) == 0 { + // This means the session is connecting for the first time, in this case we + // don't block because the session hasn't been returned yet. + return d.sess + } + + // Prevents goroutines from using the session until the connection is + // re-established. + d.connectMu.Lock() + defer d.connectMu.Unlock() + return d.sess } @@ -133,38 +240,107 @@ func (d *database) BindTx(t *sql.Tx) error { // Tx returns a BaseTx, which, if not nil, means that this session is within a // transaction func (d *database) Transaction() BaseTx { + if atomic.LoadUint64(&d.sessID) == 0 { + // This means the session is connecting for the first time, in this case we + // don't block because the session hasn't been returned yet. + return d.baseTx + } + + d.sessMu.Lock() + defer d.sessMu.Unlock() + return d.baseTx } // Name returns the database named func (d *database) Name() string { - d.databaseMu.Lock() - defer d.databaseMu.Unlock() + return d.name +} - if d.name == "" { - d.name, _ = d.PartialDatabase.FindDatabaseName() +func (d *database) getDBName() error { + name, err := d.PartialDatabase.FindDatabaseName() + if err != nil { + return err } - - return d.name + d.name = name + return nil } // BindSession binds a *sql.DB into *database func (d *database) BindSession(sess *sql.DB) error { + if err := sess.Ping(); err != nil { + return err + } + d.sessMu.Lock() + if d.sess != nil { + d.ClearCache() + d.sess.Close() // Close before rebind. + } d.sess = sess d.sessMu.Unlock() - if err := d.Ping(); err != nil { + // Does this session already have a session ID? + if atomic.LoadUint64(&d.sessID) != 0 { + return nil + } + + // Is this connection really working? + if err := d.getDBName(); err != nil { return err } + // Assign an ID if everyting was OK. d.sessID = newSessionID() - name, err := d.PartialDatabase.FindDatabaseName() - if err != nil { - return err + return nil +} + +func (d *database) isRecoverableError(err error) bool { + err = d.PartialDatabase.Err(err) + for i := 0; i < len(recoverableErrors); i++ { + if err == recoverableErrors[i] { + return true + } } + return false +} - d.name = name +// recoverFromErr attempts to reestablish a connection after a temporary error, +// returns nil if the connection was reestablished and the query can be retried. +func (d *database) recoverFromErr(err error) error { + if err == nil { + return errNothingToRecoverFrom + } + + if d.Transaction() != nil { + // Don't even attempt to recover from within a transaction, this is not + // possible. + return errors.New("Can't recover from within a bad transaction.") + } + + if !d.isRecoverableError(err) { + // This is not an error we can recover from. + return errUnableToRecover + } + + reconnect := false + + if err := d.PartialDatabase.Err(d.Ping()); err != nil { // Let's see if database/sql recovered itself. + reconnect = true + } + + if !reconnect && atomic.AddUint64(&d.recoverAttempts, 1) == flushConnectionPoolAfterRecoverAttempts { + // This happens when d.Ping() says everything is OK but queries keep + // failing, that probably means that a high number of connections in the + // pool are in bad state. Since we don't have any way to check the + // connection pool for valid connections we'll force a reconnection (once). + reconnect = true + } + + if reconnect { + // Let's attempt to connect + return d.connect(d.connFn) + } return nil } @@ -172,16 +348,22 @@ func (d *database) BindSession(sess *sql.DB) error { // Ping checks whether a connection to the database is still alive by pinging // it func (d *database) Ping() error { - if d.sess != nil { - return d.sess.Ping() + if sess := d.Session(); sess != nil { + return sess.Ping() } - return nil + if tx := d.Transaction(); tx != nil { + // This is a wrapped transaction, let's assume we had a working + // connection in the first place. + return nil + } + return db.ErrNotConnected } // ClearCache removes all caches. func (d *database) ClearCache() { d.collectionMu.Lock() defer d.collectionMu.Unlock() + d.cachedCollections.Clear() d.cachedStatements.Clear() if d.template != nil { @@ -197,7 +379,7 @@ func (d *database) Close() error { d.baseTx = nil d.sessMu.Unlock() }() - if d.sess != nil { + if sess := d.Session(); sess != nil { if cleaner, ok := d.PartialDatabase.(HasCleanUp); ok { cleaner.CleanUp() } @@ -210,6 +392,7 @@ func (d *database) Close() error { return d.sess.Close() } + // Don't close the parent session if within a transaction. if !tx.Committed() { tx.Rollback() } @@ -224,9 +407,9 @@ func (d *database) Collection(name string) db.Collection { h := cache.String(name) - ccol, ok := d.cachedCollections.ReadRaw(h) + cachedCollection, ok := d.cachedCollections.ReadRaw(h) if ok { - return ccol.(db.Collection) + return cachedCollection.(db.Collection) } col := d.PartialDatabase.NewLocalCollection(name) @@ -235,22 +418,40 @@ func (d *database) Collection(name string) db.Collection { return col } +func (d *database) prepareAndExec(stmt *exql.Statement, args ...interface{}) (string, sql.Result, error) { + p, query, err := d.prepareStatement(stmt) + if err != nil { + return query, nil, err + } + defer p.Close() + + if execer, ok := d.PartialDatabase.(HasStatementExec); ok { + res, err := execer.StatementExec(p.Stmt, args...) + return query, res, err + } + + res, err := p.Exec(args...) + return query, res, err +} + // StatementExec compiles and executes a statement that does not return any // rows. func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { var query string + queryID := newOperationID() + if db.Conf.LoggingEnabled() { defer func(start time.Time) { - status := db.QueryStatus{ - TxID: d.txID, - SessID: d.sessID, - Query: query, - Args: args, - Err: err, - Start: start, - End: time.Now(), + TxID: d.txID, + SessID: d.sessID, + QueryID: queryID, + Query: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), } if res != nil { @@ -267,47 +468,79 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + for i := 0; ; i++ { + query, res, err = d.prepareAndExec(stmt, args...) + if err == nil || i >= maxQueryRetryAttempts || !db.Conf.RetryQueryOnError() { + return res, err + } + + // Try to recover + if recoverErr := d.recoverFromErr(err); recoverErr != nil { + return nil, err // Unable to recover. + } } - defer p.Close() - if execer, ok := d.PartialDatabase.(HasStatementExec); ok { - res, err = execer.StatementExec(p.Stmt, args...) - return + panic("reached") +} + +func (d *database) prepareAndQuery(stmt *exql.Statement, args ...interface{}) (string, *sql.Rows, error) { + p, query, err := d.prepareStatement(stmt) + if err != nil { + return query, nil, err } + defer p.Close() - res, err = p.Exec(args...) - return + rows, err := p.Query(args...) + return query, rows, err } // StatementQuery compiles and executes a statement that returns rows. func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { var query string + queryID := newOperationID() + if db.Conf.LoggingEnabled() { defer func(start time.Time) { db.Log(&db.QueryStatus{ - TxID: d.txID, - SessID: d.sessID, - Query: query, - Args: args, - Err: err, - Start: start, - End: time.Now(), + TxID: d.txID, + SessID: d.sessID, + QueryID: queryID, + Query: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), }) }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + for i := 0; ; i++ { + query, rows, err = d.prepareAndQuery(stmt, args...) + if err == nil || i >= maxQueryRetryAttempts || !db.Conf.RetryQueryOnError() { + return rows, err + } + + // Try to recover + if recoverErr := d.recoverFromErr(err); recoverErr != nil { + return nil, err // Unable to recover. + } + } + + panic("reached") +} + +func (d *database) prepareAndQueryRow(stmt *exql.Statement, args ...interface{}) (string, *sql.Row, error) { + p, query, err := d.prepareStatement(stmt) + if err != nil { + return query, nil, err } defer p.Close() - rows, err = p.Query(args...) - return + // Would be nice to find a way to check if this succeeded before using + // Scan. + rows, err := p.QueryRow(args...), nil + return query, rows, nil } // StatementQueryRow compiles and executes a statement that returns at most one @@ -315,62 +548,70 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { var query string + queryID := newOperationID() + if db.Conf.LoggingEnabled() { defer func(start time.Time) { db.Log(&db.QueryStatus{ - TxID: d.txID, - SessID: d.sessID, - Query: query, - Args: args, - Err: err, - Start: start, - End: time.Now(), + TxID: d.txID, + SessID: d.sessID, + QueryID: queryID, + Query: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), }) }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + for i := 0; ; i++ { + query, row, err = d.prepareAndQueryRow(stmt, args...) + if err == nil || i >= maxQueryRetryAttempts || !db.Conf.RetryQueryOnError() { + return row, err + } + + // Try to recover + if recoverErr := d.recoverFromErr(err); recoverErr != nil { + return nil, err // Unable to recover. + } } - defer p.Close() - row, err = p.QueryRow(args...), nil - return + panic("reached") } // Driver returns the underlying *sql.DB or *sql.Tx instance. func (d *database) Driver() interface{} { if tx := d.Transaction(); tx != nil { - // A transaction return tx.(*sqlTx).Tx } - return d.sess + return d.Session() } // prepareStatement converts a *exql.Statement representation into an actual // *sql.Stmt. This method will attempt to used a cached prepared statement, if // available. func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) { - if d.sess == nil && d.Transaction() == nil { + if d.Session() == nil && d.Transaction() == nil { return nil, "", db.ErrNotConnected } pc, ok := d.cachedStatements.ReadRaw(stmt) if ok { - // The statement was cached. + // This prepared statement was cached, no need to build or to prepare + // again. ps, err := pc.(*Stmt).Open() if err == nil { return ps, ps.query, nil } } - // Plain SQL query. + // Building the actual SQL query. query := d.PartialDatabase.CompileStatement(stmt) sqlStmt, err := func() (*sql.Stmt, error) { - if d.Transaction() != nil { - return d.Transaction().(*sqlTx).Prepare(query) + if tx := d.Transaction(); tx != nil { + return tx.(*sqlTx).Prepare(query) } return d.sess.Prepare(query) }() @@ -385,43 +626,20 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) var waitForConnMu sync.Mutex -// WaitForConnection tries to execute the given connectFn function, if -// connectFn returns an error, then WaitForConnection will keep trying until -// connectFn returns nil. Maximum waiting time is 5s after having acquired the -// lock. -func (d *database) WaitForConnection(connectFn func() error) error { - // This lock ensures first-come, first-served and prevents opening too many - // file descriptors. - waitForConnMu.Lock() - defer waitForConnMu.Unlock() - - // Minimum waiting time. - waitTime := time.Millisecond * 10 - - // Waitig 5 seconds for a successful connection. - for timeStart := time.Now(); time.Now().Sub(timeStart) < time.Second*5; { - err := connectFn() - if err == nil { - return nil // Connected! - } - - // Only attempt to reconnect if the error is too many clients. - if d.PartialDatabase.Err(err) == db.ErrTooManyClients { - // Sleep and try again if, and only if, the server replied with a "too - // many clients" error. - time.Sleep(waitTime) - if waitTime < time.Millisecond*500 { - // Wait a bit more next time. - waitTime = waitTime * 2 - } - continue - } - - // Return any other error immediately. +// WaitForConnection tries to execute the given connFn function, if connFn +// returns an error, then WaitForConnection will keep trying until connFn +// returns nil. Maximum waiting time is 5s after having acquired the lock. +func (d *database) WaitForConnection(connFn func() error, isDefault bool) error { + if err := d.connect(connFn); err != nil { return err } - return db.ErrGivingUpTryingToConnect + if isDefault { + d.connectMu.Lock() + d.connFn = connFn + d.connectMu.Unlock() + } + return nil } // ReplaceWithDollarSign turns a SQL statament with '?' placeholders into @@ -448,16 +666,24 @@ func ReplaceWithDollarSign(in string) string { func newSessionID() uint64 { if atomic.LoadUint64(&lastSessID) == math.MaxUint64 { - atomic.StoreUint64(&lastSessID, 0) - return 0 + atomic.StoreUint64(&lastSessID, 1) + return 1 } return atomic.AddUint64(&lastSessID, 1) } func newTxID() uint64 { if atomic.LoadUint64(&lastTxID) == math.MaxUint64 { - atomic.StoreUint64(&lastTxID, 0) - return 0 + atomic.StoreUint64(&lastTxID, 1) + return 1 } return atomic.AddUint64(&lastTxID, 1) } + +func newOperationID() uint64 { + if atomic.LoadUint64(&lastOperationID) == math.MaxUint64 { + atomic.StoreUint64(&lastOperationID, 1) + return 1 + } + return atomic.AddUint64(&lastOperationID, 1) +} diff --git a/internal/sqladapter/exql/order_by_test.go b/internal/sqladapter/exql/order_by_test.go index 8a4ff471..f800c58b 100644 --- a/internal/sqladapter/exql/order_by_test.go +++ b/internal/sqladapter/exql/order_by_test.go @@ -19,6 +19,21 @@ func TestOrderBy(t *testing.T) { } } +func TestOrderByRaw(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: RawValue("CASE WHEN id IN ? THEN 0 ELSE 1 END")}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestOrderByDesc(t *testing.T) { o := JoinWithOrderBy( JoinSortColumns( diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 81774a2c..99c64405 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -76,7 +76,7 @@ func TestOpenMustSucceed(t *testing.T) { assert.NoError(t, err) } -func TestPreparedStatementsCache(t *testing.T) { +func TestStressPreparedStatementsCache(t *testing.T) { sess, err := Open(settings) assert.NoError(t, err) defer sess.Close() @@ -91,7 +91,11 @@ func TestPreparedStatementsCache(t *testing.T) { // The max number of elements we can have on our LRU is 128, if an statement // is evicted it will be marked as dead and will be closed only when no other // queries are using it. - const maxPreparedStatements = 128 * 2 + const preparedStatementsLRU = 128 + + // This number can be a bit greater than preparedStatementsLRU when executing + // a lot of concurrent statements. + const maxPreparedStatements = preparedStatementsLRU + 40 var wg sync.WaitGroup for i := 0; i < 1000; i++ { @@ -106,7 +110,8 @@ func TestPreparedStatementsCache(t *testing.T) { if err != nil { tFatal(err) } - if activeStatements := sqladapter.NumActiveStatements(); activeStatements > maxPreparedStatements { + activeStatements := sqladapter.NumActiveStatements() + if activeStatements > maxPreparedStatements { tFatal(fmt.Errorf("The number of active statements cannot exceed %d (got %d).", maxPreparedStatements, activeStatements)) } }(i) @@ -116,6 +121,11 @@ func TestPreparedStatementsCache(t *testing.T) { } wg.Wait() + + activeStatements := sqladapter.NumActiveStatements() + if activeStatements > preparedStatementsLRU { + t.Fatal(fmt.Errorf("The number of active statements cannot exceed %d (got %d).", preparedStatementsLRU, activeStatements)) + } } func TestTruncateAllCollections(t *testing.T) { @@ -1429,7 +1439,7 @@ func TestBuilder(t *testing.T) { assert.NotZero(t, all) } -func TestExhaustConnectionPool(t *testing.T) { +func TestExhaustConnectionPoolWithTransactions(t *testing.T) { if Adapter == "ql" { t.Skip("Currently not supported.") } @@ -1461,7 +1471,7 @@ func TestExhaustConnectionPool(t *testing.T) { // Requesting a new transaction session. start := time.Now() - tLogf("Tx: %d: NewTx") + tLogf("Tx: %d: NewTx", i) tx, err := sess.NewTx() if err != nil { tFatal(err) diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index d4ccfd7f..e96f80ac 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -22,6 +22,71 @@ func TestSelect(t *testing.T) { b.SelectFrom("artist").String(), ) + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000, 2000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000}, + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN (NULL) THEN 0 ELSE 1 END") + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + rawCase.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN (?, ?) THEN 0 ELSE 1 END", 1000, 2000) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + rawCase.Arguments(), + ) + } + { sel := b.Select(db.Func("DISTINCT", "name")).From("artist") assert.Equal( @@ -49,15 +114,29 @@ func TestSelect(t *testing.T) { b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", db.Raw("column"))}).String(), ) - assert.Equal( - `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, - b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}}).String(), - ) + { + q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, + q.String(), + ) + assert.Equal( + []interface{}{0, -1}, + q.Arguments(), + ) + } - assert.Equal( - `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, - b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}}).String(), - ) + { + q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, + q.String(), + ) + assert.Equal( + []interface{}{-1}, + q.Arguments(), + ) + } assert.Equal( `SELECT * FROM "artist" WHERE ("id" IN ($1, $2))`, @@ -305,7 +384,7 @@ func TestSelect(t *testing.T) { ) assert.Equal( - `SELECT * FROM "artist" WHERE ("id" IS NULL)`, + `SELECT * FROM "artist" WHERE ("id" IN (NULL))`, b.SelectFrom("artist").Where(db.Cond{"id": []int64{}}).String(), ) @@ -688,7 +767,7 @@ func TestUpdate(t *testing.T) { idSlice := []int64{} q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}) assert.Equal( - `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, q.String(), ) assert.Equal( @@ -701,7 +780,7 @@ func TestUpdate(t *testing.T) { idSlice := []int64{} q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10}) assert.Equal( - `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, q.String(), ) assert.Equal( diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index e1ecd446..b410281f 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -26,33 +26,42 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) { argn := 0 + argx := make([]interface{}, 0, len(args)) for i := 0; i < len(in); i++ { if in[i] == '?' { - if len(args) > argn { // we have arguments to match. - u := toInterfaceArguments(args[argn]) + if len(args) > argn { k := `?` - if len(u) > 1 { - // An array of arguments - k = `(?` + strings.Repeat(`, ?`, len(u)-1) + `)` - } else if len(u) == 1 { - if rawValue, ok := u[0].(db.RawValue); ok { - k = rawValue.Raw() - u = []interface{}{} + values, isSlice := toInterfaceArguments(args[argn]) + if isSlice { + if len(values) == 0 { + k = `(NULL)` + } else { + k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)` + } + } else { + if len(values) == 1 { + if rawValue, ok := values[0].(db.RawValue); ok { + k, values = rawValue.Raw(), nil + } + } else if len(values) == 0 { + k = `NULL` } } - lk := len(k) - if lk > 1 { + if k != `?` { in = in[:i] + k + in[i+1:] i += len(k) - 1 } - args = append(args[:argn], append(u, args[argn+1:]...)...) - argn += len(u) + + if len(values) > 0 { + argx = append(argx, values...) + } + argn++ } } } - return in, args + return in, argx } // ToWhereWithArguments converts the given parameters into a exql.Where @@ -154,7 +163,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] fnName := t.Name() fnArgs := []interface{}{} - args := toInterfaceArguments(t.Arguments()) + args, _ := toInterfaceArguments(t.Arguments()) fragments := []string{} for i := range args { frag, args := tu.PlaceholderValue(args[i]) @@ -169,33 +178,30 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] } // toInterfaceArguments converts the given value into an array of interfaces. -func toInterfaceArguments(value interface{}) (args []interface{}) { +func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) { + v := reflect.ValueOf(value) + if value == nil { - return nil + return nil, false } - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: + if v.Type().Kind() == reflect.Slice { var i, total int + + // Byte slice gets transformed into a string. if v.Type().Elem().Kind() == reflect.Uint8 { - return []interface{}{string(value.([]byte))} + return []interface{}{string(value.([]byte))}, false } + total = v.Len() - if total > 0 { - args = make([]interface{}, total) - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - return args + args = make([]interface{}, total) + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() } - return nil - default: - args = []interface{}{value} + return args, true } - return args + return []interface{}{value}, false } // ToColumnValues converts the given conditions into a exql.ColumnValues struct. @@ -265,7 +271,7 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal // A function with one or more arguments. fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) columnValue.Value = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: @@ -273,27 +279,33 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal columnValue.Value = exql.RawValue(expanded) args = append(args, rawArgs...) default: - v := toInterfaceArguments(value) + v, isSlice := toInterfaceArguments(value) - if v == nil { - // Nil value given. - columnValue.Value = sqlNull + if isSlice { if columnValue.Operator == "" { - columnValue.Operator = sqlIsOperator + columnValue.Operator = sqlInOperator } - } else { - if len(v) > 1 || reflect.TypeOf(value).Kind() == reflect.Slice { + if len(v) > 0 { // Array value given. columnValue.Value = exql.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = exql.RawValue(`(NULL)`) + } + args = append(args, v...) + } else { + if v == nil { + // Nil value given. + columnValue.Value = sqlNull if columnValue.Operator == "" { - columnValue.Operator = sqlInOperator + columnValue.Operator = sqlIsOperator } } else { - // Single value given. columnValue.Value = sqlPlaceholder + args = append(args, v...) } - args = append(args, v...) } + } // Using guessed operator if no operator was given. diff --git a/lib/sqlbuilder/fetch.go b/lib/sqlbuilder/fetch.go index b5bf577f..e3f6b1fa 100644 --- a/lib/sqlbuilder/fetch.go +++ b/lib/sqlbuilder/fetch.go @@ -81,6 +81,9 @@ func fetchRow(rows *sql.Rows, dst interface{}) error { // slice of structs given by the pointer `dst`. func fetchRows(rows *sql.Rows, dst interface{}) error { var err error + if rows == nil { + panic("rows cannot be nil") + } defer rows.Close() diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go index 82f472cd..3f05da39 100644 --- a/lib/sqlbuilder/placeholder_test.go +++ b/lib/sqlbuilder/placeholder_test.go @@ -48,7 +48,7 @@ func TestPlaceholderArray(t *testing.T) { { ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) - assert.Equal(t, "(?, ?, ?)?", ret) + assert.Equal(t, "(?, ?, ?)(NULL)", ret) } } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 131911df..695748ee 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -35,7 +35,7 @@ type selector struct { groupBy *exql.GroupBy groupByArgs []interface{} - orderBy exql.OrderBy + orderBy *exql.OrderBy orderByArgs []interface{} limit exql.Limit @@ -161,7 +161,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { Column: exql.RawValue(col), } qs.mu.Lock() - qs.orderByArgs = args + qs.orderByArgs = append(qs.orderByArgs, args...) qs.mu.Unlock() case db.Function: fnName, fnArgs := value.Name(), value.Arguments() @@ -175,7 +175,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { Column: exql.RawValue(expanded), } qs.mu.Lock() - qs.orderByArgs = fnArgs + qs.orderByArgs = append(qs.orderByArgs, fnArgs...) qs.mu.Unlock() case string: if strings.HasPrefix(value, "-") { @@ -204,7 +204,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { } qs.mu.Lock() - qs.orderBy.SortColumns = &sortColumns + qs.orderBy = &exql.OrderBy{ + SortColumns: &sortColumns, + } qs.mu.Unlock() return qs @@ -332,7 +334,7 @@ func (qs *selector) statement() *exql.Statement { Offset: qs.offset, Joins: exql.JoinConditions(qs.joins...), Where: qs.where, - OrderBy: &qs.orderBy, + OrderBy: qs.orderBy, GroupBy: qs.groupBy, } } diff --git a/logger.go b/logger.go index 1703fb58..4f8d949f 100644 --- a/logger.go +++ b/logger.go @@ -47,8 +47,9 @@ var ( // QueryStatus represents the status of a query after being executed. type QueryStatus struct { - SessID uint64 - TxID uint64 + SessID uint64 + TxID uint64 + QueryID uint64 RowsAffected *int64 LastInsertID *int64 diff --git a/mysql/database.go b/mysql/database.go index 8263e40c..ac4f596e 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -120,7 +120,7 @@ func (d *database) open() error { return err } - if err := d.BaseDatabase.WaitForConnection(connFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(connFn, true); err != nil { return err } @@ -141,7 +141,10 @@ func (d *database) clone() (*database, error) { } clone.Builder = b - clone.BaseDatabase.BindSession(d.BaseDatabase.Session()) + if err := clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil { + return nil, err + } + return clone, nil } @@ -182,14 +185,19 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { } connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sess := clone.BaseDatabase.Session() + if sess == nil { + return db.ErrNotConnected + } + + sqlTx, err := sess.Begin() if err == nil { return clone.BindTx(sqlTx) } return err } - if err := d.BaseDatabase.WaitForConnection(connFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(connFn, false); err != nil { return nil, err } diff --git a/postgresql/database.go b/postgresql/database.go index d7807d99..e87799f5 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -119,7 +119,7 @@ func (d *database) open() error { return err } - if err := d.BaseDatabase.WaitForConnection(connFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(connFn, true); err != nil { return err } @@ -140,7 +140,10 @@ func (d *database) clone() (*database, error) { } clone.Builder = b - clone.BaseDatabase.BindSession(d.BaseDatabase.Session()) + if err := clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil { + return nil, err + } + return clone, nil } @@ -153,10 +156,21 @@ func (d *database) CompileStatement(stmt *exql.Statement) string { // Err allows sqladapter to translate some known errors into generic errors. func (d *database) Err(err error) error { if err != nil { + // These errors are not exported so we have to check them by comparing + // string values. s := err.Error() - // These errors are not exported so we have to check them by they string value. - if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) || strings.Contains(s, `too many open`) { + switch { + case strings.Contains(s, `too many clients`), + strings.Contains(s, `remaining connection slots are reserved`), + strings.Contains(s, `too many open`): return db.ErrTooManyClients + case strings.Contains(s, `connection refused`), + strings.Contains(s, `reset by peer`), + strings.Contains(s, `is starting up`), + strings.Contains(s, `is in recovery mode`), + strings.Contains(s, `is closed`), + strings.Contains(s, `is shutting down`): + return db.ErrServerRefusedConnection } } return err @@ -184,17 +198,20 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sess := clone.BaseDatabase.Session() + if sess == nil { + return db.ErrNotConnected + } + sqlTx, err := sess.Begin() if err == nil { return clone.BindTx(sqlTx) } return err } - if err := d.BaseDatabase.WaitForConnection(connFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(connFn, false); err != nil { return nil, err } - return sqladapter.NewTx(clone), nil } diff --git a/postgresql/local_test.go b/postgresql/local_test.go index b09e86ad..47d68b76 100644 --- a/postgresql/local_test.go +++ b/postgresql/local_test.go @@ -2,6 +2,8 @@ package postgresql import ( "database/sql" + "os" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -117,3 +119,47 @@ func TestIssue210(t *testing.T) { _, err = sess.Collection("hello").Find().Count() assert.NoError(t, err) } + +// The "driver: bad connection" problem (driver.ErrBadConn) happens when the +// database process is abnormally interrupted, this can happen for a variety of +// reasons, for instance when the database process is OOM killed by the OS. +func TestDriverBadConnection(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + db.Conf.SetRetryQueryOnError(true) + defer db.Conf.SetRetryQueryOnError(false) + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + t.Fatal(err) + } + + const concurrentOpts = 20 + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := sess.Collection("artist").Find().Count() + if err != nil { + tFatal(err) + } + }(i) + if i%concurrentOpts == (concurrentOpts - 1) { + // This triggers the "bad connection" problem, if you want to see that + // instead of retrying the query set maxQueryRetryAttempts to 0. + wg.Add(1) + go func() { + defer wg.Done() + sess.Query("SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = ?", os.Getenv("DB_NAME")) + }() + wg.Wait() + } + } + + wg.Wait() +} diff --git a/ql/database.go b/ql/database.go index 3a5134cc..60e91f77 100644 --- a/ql/database.go +++ b/ql/database.go @@ -24,6 +24,7 @@ package ql import ( "database/sql" "errors" + "log" "sync" "sync/atomic" @@ -198,7 +199,7 @@ func (d *database) open() error { return errTooManyOpenFiles } - if err := d.BaseDatabase.WaitForConnection(openFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(openFn, true); err != nil { return err } @@ -227,6 +228,7 @@ func (d *database) CompileStatement(stmt *exql.Statement) string { // Err allows sqladapter to translate some known errors into generic errors. func (d *database) Err(err error) error { if err != nil { + log.Printf("Err: %v", err) if err == errTooManyOpenFiles { return db.ErrTooManyClients } @@ -282,14 +284,18 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() openFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sess := clone.BaseDatabase.Session() + if sess == nil { + return db.ErrNotConnected + } + sqlTx, err := sess.Begin() if err == nil { return clone.BindTx(sqlTx) } return err } - if err := d.BaseDatabase.WaitForConnection(openFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(openFn, false); err != nil { return nil, err } diff --git a/sqlite/database.go b/sqlite/database.go index db824671..f6aa7380 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -140,7 +140,7 @@ func (d *database) open() error { return errTooManyOpenFiles } - if err := d.BaseDatabase.WaitForConnection(openFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(openFn, true); err != nil { return err } @@ -195,6 +195,10 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { } openFn := func() error { + sess := clone.BaseDatabase.Session() + if sess == nil { + return db.ErrNotConnected + } sqlTx, err := clone.BaseDatabase.Session().Begin() if err == nil { return clone.BindTx(sqlTx) @@ -202,7 +206,7 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { return err } - if err := d.BaseDatabase.WaitForConnection(openFn); err != nil { + if err := d.BaseDatabase.WaitForConnection(openFn, true); err != nil { return nil, err }