diff --git a/database/driver.go b/database/driver.go index 7f60962d..e7ecb442 100644 --- a/database/driver.go +++ b/database/driver.go @@ -147,9 +147,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx // Prepare a statement. func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) { - if c.needsMutex(query) { - c.acquire() - } stmt, err = c.wrapped.Prepare(query) stmt = &Stmt{ conn: c, @@ -161,9 +158,6 @@ func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) { // PrepareContext prepares a statement with context. func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { - if c.needsMutex(query) { - c.acquire() - } if p, cast := c.wrapped.(driver.ConnPrepareContext); cast { stmt, err = p.PrepareContext(ctx, query) } else { @@ -230,7 +224,6 @@ type Stmt struct { // Close the statement. func (s *Stmt) Close() (err error) { - defer s.conn.release() err = s.wrapped.Close() return } @@ -243,12 +236,20 @@ func (s *Stmt) NumInput() (n int) { // Exec executes the statement. func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } r, err = s.wrapped.Exec(args) return } // ExecContext executes the statement with context. func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } if p, cast := s.wrapped.(driver.StmtExecContext); cast { r, err = p.ExecContext(ctx, args) } else { @@ -259,12 +260,20 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri // Query executes a query. func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } r, err = s.wrapped.Query(args) return } // QueryContext executes a query. func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } if p, cast := s.wrapped.(driver.StmtQueryContext); cast { r, err = p.QueryContext(ctx, args) } else { @@ -281,6 +290,12 @@ func (s *Stmt) values(named []driver.NamedValue) (out []driver.Value) { return } +// needsMutex returns true when the query should is a write operation. +func (s *Stmt) needsMutex() (matched bool) { + matched = s.conn.needsMutex(s.query) + return +} + // Tx is a transaction. type Tx struct { wrapped driver.Tx