@@ -2,49 +2,51 @@ package pgfsm
2
2
3
3
import (
4
4
"context"
5
- "database/sql"
6
5
_ "embed"
7
6
"errors"
8
7
"log/slog"
9
8
"strings"
9
+
10
+ "github.com/jackc/pgx/v5"
11
+ "github.com/jackc/pgx/v5/pgxpool"
10
12
)
11
13
12
- func transaction (ctx context.Context , db * sql. DB , fn func (ctx context.Context , tx * sql .Tx ) error ) error {
13
- tx , err := db .BeginTx (ctx , & sql. TxOptions {} )
14
+ func transaction (ctx context.Context , db * pgxpool. Pool , fn func (ctx context.Context , tx pgx .Tx ) error ) error {
15
+ tx , err := db .Begin (ctx )
14
16
if err != nil {
15
17
return err
16
18
}
17
19
18
20
if err = fn (ctx , tx ); err != nil {
19
- txErr := tx .Rollback ()
20
- if errors .Is (txErr , sql . ErrTxDone ) {
21
+ txErr := tx .Rollback (ctx )
22
+ if errors .Is (txErr , pgx . ErrTxClosed ) {
21
23
return err
22
24
}
23
25
24
26
return errors .Join (err , txErr )
25
27
}
26
28
27
- err = tx .Commit ()
28
- if errors .Is (err , sql . ErrTxDone ) {
29
+ err = tx .Commit (ctx )
30
+ if errors .Is (err , pgx . ErrTxClosed ) {
29
31
return nil
30
32
}
31
33
32
34
return err
33
35
}
34
36
35
- func insert (ctx context.Context , tx * sql .Tx , encoder Encoding , cmd Command ) error {
37
+ func insert (ctx context.Context , tx pgx .Tx , encoder Encoding , cmd Command ) error {
36
38
data , err := encoder .Encode (cmd )
37
39
if err != nil {
38
40
return err
39
41
}
40
42
41
43
const q = `INSERT INTO pgfsm.command (kind, data) VALUES ($1, $2)`
42
44
43
- _ , err = tx .ExecContext (ctx , q , cmd .Kind (), data )
45
+ _ , err = tx .Exec (ctx , q , cmd .Kind (), data )
44
46
return err
45
47
}
46
48
47
- func next (ctx context.Context , tx * sql .Tx ) (int64 , string , []byte , error ) {
49
+ func next (ctx context.Context , tx pgx .Tx ) (int64 , string , []byte , error ) {
48
50
const q = `
49
51
SELECT id, kind, data FROM pgfsm.command
50
52
ORDER BY id ASC
@@ -58,24 +60,24 @@ func next(ctx context.Context, tx *sql.Tx) (int64, string, []byte, error) {
58
60
data []byte
59
61
)
60
62
61
- if err := tx .QueryRowContext (ctx , q ).Scan (& id , & kind , & data ); err != nil {
63
+ if err := tx .QueryRow (ctx , q ).Scan (& id , & kind , & data ); err != nil {
62
64
return 0 , "" , []byte {}, err
63
65
}
64
66
65
67
return id , kind , data , nil
66
68
}
67
69
68
- func remove (ctx context.Context , tx * sql .Tx , id int64 ) error {
70
+ func remove (ctx context.Context , tx pgx .Tx , id int64 ) error {
69
71
const q = `DELETE FROM pgfsm.command WHERE id = $1`
70
72
71
- _ , err := tx .ExecContext (ctx , q , id )
73
+ _ , err := tx .Exec (ctx , q , id )
72
74
return err
73
75
}
74
76
75
77
//go:embed migrate.sql
76
78
var migration string
77
79
78
- func migrateUp (ctx context.Context , db * sql. DB , logger * slog.Logger ) error {
80
+ func migrateUp (ctx context.Context , db * pgxpool. Pool , logger * slog.Logger ) error {
79
81
logger .DebugContext (ctx , "performing migrations" )
80
82
81
83
statements := strings .Split (migration , ";" )
@@ -89,7 +91,7 @@ func migrateUp(ctx context.Context, db *sql.DB, logger *slog.Logger) error {
89
91
With (slog .String ("statement" , statement )).
90
92
DebugContext (ctx , "executing statement" )
91
93
92
- if _ , err := db .ExecContext (ctx , statement ); err != nil {
94
+ if _ , err := db .Exec (ctx , statement ); err != nil {
93
95
return err
94
96
}
95
97
}
0 commit comments