Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff Ortel <[email protected]>
  • Loading branch information
jortel committed Oct 9, 2024
1 parent 8a3e80c commit 9af9450
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 35 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ endif
test:
go test -count=1 -v $(shell go list ./... | grep -v "hub/test")

test-db:
go test -count=1 -timeout=6h -v ./database...

# Run Hub REST API tests.
test-api:
HUB_BASE_URL=$(HUB_BASE_URL) go test -count=1 -p=1 -v -failfast ./test/api/...
Expand Down
100 changes: 65 additions & 35 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import (
"github.com/mattn/go-sqlite3"
)

// Driver is a wrapper around the SQLite driver.
// The purpose is to prevent database locked errors using
// a mutex around write operations.
type Driver struct {
mutex sync.Mutex
wrapped driver.Driver
dsn string
}

// Open a connection.
func (d *Driver) Open(dsn string) (conn driver.Conn, err error) {
d.wrapped = &sqlite3.SQLiteDriver{}
conn, err = d.wrapped.Open(dsn)
Expand All @@ -28,41 +32,56 @@ func (d *Driver) Open(dsn string) (conn driver.Conn, err error) {
return
}

// OpenConnector opens a connection.
func (d *Driver) OpenConnector(dsn string) (dc driver.Connector, err error) {
d.dsn = dsn
dc = d
return
}

// Connect opens a connection.
func (d *Driver) Connect(context.Context) (conn driver.Conn, err error) {
conn, err = d.Open(d.dsn)
return
}

// Driver returns the underlying driver.
func (d *Driver) Driver() driver.Driver {
return d
}

// Conn is a DB connection.
type Conn struct {
mutex *sync.Mutex
wrapped driver.Conn
hasMutex bool
hasTx bool
}

// Ping the DB.
func (c *Conn) Ping(ctx context.Context) (err error) {
if p, cast := c.wrapped.(driver.Pinger); cast {
err = p.Ping(ctx)
}
return
}

// ResetSession reset the connection.
// - Reset the Tx.
// - Release the mutex.
func (c *Conn) ResetSession(ctx context.Context) (err error) {
defer c.release()
defer func() {
c.hasTx = false
c.release()
}()
if p, cast := c.wrapped.(driver.SessionResetter); cast {
err = p.ResetSession(ctx)
}
return
}

// IsValid returns true when the connection is valid.
// When true, the connection may be reused by the sql package.
func (c *Conn) IsValid() (b bool) {
b = true
if p, cast := c.wrapped.(driver.Validator); cast {
Expand All @@ -71,6 +90,7 @@ func (c *Conn) IsValid() (b bool) {
return
}

// QueryContext execute a query with context.
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) {
defer c.release()
if c.needsMutex(query) {
Expand All @@ -82,6 +102,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return
}

// ExecContext executes an SQL/DDL statement with context.
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
defer c.release()
if c.needsMutex(query) {
Expand All @@ -93,6 +114,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name
return
}

// Begin a transaction.
func (c *Conn) Begin() (tx driver.Tx, err error) {
c.acquire()
tx, err = c.wrapped.Begin()
Expand All @@ -103,9 +125,11 @@ func (c *Conn) Begin() (tx driver.Tx, err error) {
conn: c,
wrapped: tx,
}
c.hasTx = true
return
}

// BeginTx begins a transaction.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
c.acquire()
if p, cast := c.wrapped.(driver.ConnBeginTx); cast {
Expand All @@ -117,9 +141,11 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx
conn: c,
wrapped: tx,
}
c.hasTx = true
return
}

// Prepare a statement.
func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
Expand All @@ -128,10 +154,12 @@ func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) {
stmt = &Stmt{
conn: c,
wrapped: stmt,
query: query,
}
return
}

// PrepareContext prepares a statement with context.
func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
Expand All @@ -144,16 +172,20 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.St
stmt = &Stmt{
conn: c,
wrapped: stmt,
query: query,
}
return
}

// Close the connection.
func (c *Conn) Close() (err error) {
err = c.wrapped.Close()
c.hasMutex = false
c.release()
return
}

// needsMutex returns true when the query should is a write operation.
func (c *Conn) needsMutex(query string) (matched bool) {
if query == "" {
return
Expand All @@ -168,43 +200,55 @@ func (c *Conn) needsMutex(query string) (matched bool) {
return
}

// acquire the mutex.
// Since Locks are not reentrant, the mutex is acquired
// only if this connection has not already acquired it.
func (c *Conn) acquire() {
if !c.hasMutex {
c.mutex.Lock()
c.hasMutex = true
}
}

// release the mutex.
// Released only when:
// - This connection has acquired it
// - Not in a transaction.
func (c *Conn) release() {
if c.hasMutex {
if c.hasMutex && !c.hasTx {
c.mutex.Unlock()
c.hasMutex = false
}
}

// Stmt is a SQL/DDL statement.
type Stmt struct {
wrapped driver.Stmt
conn *Conn
query string
}

// Close the statement.
func (s *Stmt) Close() (err error) {
defer s.conn.release()
err = s.wrapped.Close()
return
}

// NumInput returns the number of (query) input parameters.
func (s *Stmt) NumInput() (n int) {
n = s.wrapped.NumInput()
return
}

// Exec executes the statement.
func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) {
defer s.conn.release()
r, err = s.wrapped.Exec(args)
return
}

// ExecContext executes the statement with context.
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
defer s.conn.release()
if p, cast := s.wrapped.(driver.StmtExecContext); cast {
r, err = p.ExecContext(ctx, args)
} else {
Expand All @@ -213,69 +257,55 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri
return
}

// Query executes a query.
func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) {
r, err = s.wrapped.Query(args)
r = &Rows{
conn: s.conn,
wrapped: r,
}
return
}

// QueryContext executes a query.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
if p, cast := s.wrapped.(driver.StmtQueryContext); cast {
r, err = p.QueryContext(ctx, args)
} else {
r, err = s.Query(s.values(args))
}
r = &Rows{
conn: s.conn,
wrapped: r,
}
return
}

// values converts named-values to values.
func (s *Stmt) values(named []driver.NamedValue) (out []driver.Value) {
for i := range named {
out = append(out, named[i].Value)
}
return
}

// Tx is a transaction.
type Tx struct {
wrapped driver.Tx
conn *Conn
}

// Commit the transaction.
// Releases the mutex.
func (t *Tx) Commit() (err error) {
defer t.conn.release()
defer func() {
t.conn.hasTx = false
t.conn.release()
}()
err = t.wrapped.Commit()
return
}

//
// Rollback the transaction.
// Releases the mutex.
func (t *Tx) Rollback() (err error) {
defer t.conn.release()
defer func() {
t.conn.hasTx = false
t.conn.release()
}()
err = t.wrapped.Rollback()
return
}

type Rows struct {
conn *Conn
wrapped driver.Rows
}

func (r *Rows) Columns() (s []string) {
s = r.wrapped.Columns()
return
}

func (r *Rows) Close() (err error) {
defer r.conn.release()
err = r.wrapped.Close()
return
}

func (r *Rows) Next(object []driver.Value) (err error) {
err = r.wrapped.Next(object)
return
}

0 comments on commit 9af9450

Please sign in to comment.