diff --git a/database.go b/database.go index f240a4d..b2b46f4 100644 --- a/database.go +++ b/database.go @@ -24,7 +24,7 @@ type database struct { // Other things pc singleflightx.Group[string, []byte] // persistant cache pcm sync.RWMutex // post creation - sp sync.Mutex // short path creation + sp singleflightx.Group[string, string] // short path creation debug bool } @@ -200,64 +200,6 @@ func (db *database) QueryRowContext(ctx context.Context, query string, args ...a return db.readDb.QueryRowContext(ctx, query, args...), nil } -type transaction struct { - tx *sql.Tx - db *database -} - -func (db *database) Begin() (*transaction, error) { - return db.BeginTx(context.Background(), nil) -} - -func (db *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*transaction, error) { - if db == nil || db.writeDb == nil { - return nil, errors.New("database not initialized") - } - tx, err := db.writeDb.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - return &transaction{tx: tx, db: db}, nil -} - -func (tx *transaction) Exec(query string, args ...any) (sql.Result, error) { - return tx.ExecContext(context.Background(), query, args...) -} - -func (tx *transaction) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { - ctx = tx.db.dbBefore(ctx, query, args...) - defer tx.db.dbAfter(ctx, query, args...) - return tx.tx.ExecContext(ctx, query, args...) -} - -func (tx *transaction) Query(query string, args ...any) (*sql.Rows, error) { - return tx.QueryContext(context.Background(), query, args...) -} - -func (tx *transaction) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - ctx = tx.db.dbBefore(ctx, query, args...) - defer tx.db.dbAfter(ctx, query, args...) - return tx.tx.QueryContext(ctx, query, args...) -} - -func (tx *transaction) QueryRow(query string, args ...any) *sql.Row { - return tx.QueryRowContext(context.Background(), query, args...) -} - -func (tx *transaction) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { - ctx = tx.db.dbBefore(ctx, query, args...) - defer tx.db.dbAfter(ctx, query, args...) - return tx.tx.QueryRowContext(ctx, query, args...) -} - -func (tx *transaction) Commit() error { - return tx.tx.Commit() -} - -func (tx *transaction) Rollback() error { - return tx.tx.Rollback() -} - func (db *database) dump(file string) { if db == nil || db.readDb == nil { return diff --git a/shortPath.go b/shortPath.go index db960b6..8a80fa7 100644 --- a/shortPath.go +++ b/shortPath.go @@ -10,31 +10,41 @@ func (db *database) shortenPath(p string) (string, error) { if p == "" { return "", errors.New("empty path") } - - db.sp.Lock() - defer db.sp.Unlock() - - result, err := db.shortenPathTransaction(p) - if err != nil { + result, err, _ := db.sp.Do(p, func() (string, error) { + sp, err := db.queryShortPath(p) + if err == nil && sp != "" { + return sp, nil + } else if !errors.Is(err, sql.ErrNoRows) { + return "", err + } + // In case it wasn't shortened yet ... + err = db.createShortPath(p) + if err != nil { + return "", err + } + // Query again + sp, err = db.queryShortPath(p) + if err == nil && sp != "" { + return sp, nil + } return "", err - } - - return result, nil + }) + return result, err } -func (db *database) shortenPathTransaction(p string) (string, error) { - tx, err := db.Begin() +func (db *database) queryShortPath(p string) (string, error) { + row, err := db.QueryRow("select id from shortpath where path = @path", sql.Named("path", p)) if err != nil { return "", err } - defer tx.Rollback() - var id int64 - err = tx.QueryRow("select id from shortpath where path = @path", sql.Named("path", p)).Scan(&id) - if err == sql.ErrNoRows { - // Path doesn't exist, insert new entry with the lowest available id - err = tx.QueryRow(` - WITH RECURSIVE ids(n) AS ( + err = row.Scan(&id) + return fmt.Sprintf("/s/%x", id), err +} + +func (db *database) createShortPath(p string) error { + _, err := db.Exec(` + WITH RECURSIVE ids(n) AS ( SELECT 1 UNION ALL SELECT n + 1 FROM ids @@ -48,19 +58,6 @@ func (db *database) shortenPathTransaction(p string) (string, error) { ORDER BY n LIMIT 1), ? - ) - RETURNING id - `, p).Scan(&id) - if err != nil { - return "", err - } - } else if err != nil { - return "", err - } - - if err = tx.Commit(); err != nil { - return "", err - } - - return fmt.Sprintf("/s/%x", id), nil + )`, p) + return err }