Skip to content

Commit

Permalink
mutex stmt execution.
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff Ortel <[email protected]>
  • Loading branch information
jortel committed Oct 9, 2024
1 parent 9af9450 commit cb4ebee
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx

// Prepare a statement.
func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
stmt, err = c.wrapped.Prepare(query)
stmt = &Stmt{
conn: c,
Expand All @@ -161,9 +158,6 @@ func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) {

// PrepareContext prepares a statement with context.
func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.ConnPrepareContext); cast {
stmt, err = p.PrepareContext(ctx, query)
} else {
Expand Down Expand Up @@ -230,7 +224,6 @@ type Stmt struct {

// Close the statement.
func (s *Stmt) Close() (err error) {
defer s.conn.release()
err = s.wrapped.Close()
return
}
Expand All @@ -243,12 +236,20 @@ func (s *Stmt) NumInput() (n int) {

// Exec executes the statement.
func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) {
if s.needsMutex() {
s.conn.acquire()
defer s.conn.release()
}
r, err = s.wrapped.Exec(args)
return
}

// ExecContext executes the statement with context.
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
if s.needsMutex() {
s.conn.acquire()
defer s.conn.release()
}
if p, cast := s.wrapped.(driver.StmtExecContext); cast {
r, err = p.ExecContext(ctx, args)
} else {
Expand All @@ -259,12 +260,20 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri

// Query executes a query.
func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) {
if s.needsMutex() {
s.conn.acquire()
defer s.conn.release()
}
r, err = s.wrapped.Query(args)
return
}

// QueryContext executes a query.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
if s.needsMutex() {
s.conn.acquire()
defer s.conn.release()
}
if p, cast := s.wrapped.(driver.StmtQueryContext); cast {
r, err = p.QueryContext(ctx, args)
} else {
Expand All @@ -281,6 +290,12 @@ func (s *Stmt) values(named []driver.NamedValue) (out []driver.Value) {
return
}

// needsMutex returns true when the query should is a write operation.
func (s *Stmt) needsMutex() (matched bool) {
matched = s.conn.needsMutex(s.query)
return
}

// Tx is a transaction.
type Tx struct {
wrapped driver.Tx
Expand Down

0 comments on commit cb4ebee

Please sign in to comment.