Skip to content

Commit 381d74a

Browse files
authored
Concurrency support (#5)
This commit provides a new option named SetConcurrency that allows the user of the package to specify the number of queued commands to process concurrently. This effectively adds batching with some slight differences. Mainly, commands in a batch are all managed within a single database transaction, so are dependent on each other's success. If any one command in the batch fails then all commands in the batch will return to the queue. This makes idempotency quite important. I wanted to add this kind of batching as I figured you could end up in situations where all instances of the FSM get blocked waiting on a chain of dependent commands, so allowing nodes to process large amounts of commands in a batch can help mitigate that problem. Signed-off-by: David Bond <[email protected]>
1 parent 1d27da3 commit 381d74a

File tree

4 files changed

+119
-107
lines changed

4 files changed

+119
-107
lines changed

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ any other tables applications built on top of this package may use.
1717

1818
Commands use an ordinal identifier to ensure that they're processed in order of insertion. This is an implementation
1919
detail not exposed to users of the package. Once processed, their entry in the `pgfsm.command` table is removed,
20-
providing once-only processing of individual commands.
20+
providing once-only processing of individual commands. When using a concurrency greater than one, this package
21+
provides at-least-once processing of commands in batches. Batches of commands are dependent on each other as they are
22+
handled within the same transaction. If a single command within the batch fails, all commands in the batch are returned
23+
to the database.
2124

22-
When reading commands, a transaction is used to ensure that should command processing fail, the contents of the command
23-
are returned to the `pgfsm.command` table with the same identifier. This package uses a registration system for commands
24-
utilising parameterised types that allows commands to be directly decoded into their concrete types without excessive
25-
usage of reflection.
25+
This package uses a registration system for commands utilising parameterised types that allows commands to be directly
26+
decoded into their concrete types without excessive usage of reflection.
2627

2728
When handling a command, you have the option to return a command as a result. This allows commands to act as a graph of
2829
sorts, where the successful processing of one command creates zero or more child commands. This can be used to define

database.go

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
_ "embed"
66
"errors"
7-
"log/slog"
87
"strings"
98

109
"github.com/jackc/pgx/v5"
@@ -46,51 +45,55 @@ func insert(ctx context.Context, tx pgx.Tx, encoder Encoding, cmd Command) error
4645
return err
4746
}
4847

49-
func next(ctx context.Context, tx pgx.Tx) (int64, string, []byte, error) {
50-
const q = `
51-
SELECT id, kind, data FROM pgfsm.command
52-
ORDER BY id ASC
53-
FOR UPDATE SKIP LOCKED
54-
LIMIT 1
55-
`
56-
57-
var (
58-
id int64
48+
type (
49+
record struct {
50+
id uint64
5951
kind string
6052
data []byte
61-
)
53+
}
54+
)
55+
56+
func next(ctx context.Context, tx pgx.Tx, limit uint) ([]record, error) {
57+
const q = `
58+
DELETE FROM pgfsm.command WHERE id IN (
59+
SELECT id FROM pgfsm.command
60+
ORDER BY id ASC
61+
FOR UPDATE SKIP LOCKED
62+
LIMIT $1
63+
) RETURNING id, kind, data
64+
`
6265

63-
if err := tx.QueryRow(ctx, q).Scan(&id, &kind, &data); err != nil {
64-
return 0, "", []byte{}, err
66+
rows, err := tx.Query(ctx, q, limit)
67+
if err != nil {
68+
return nil, err
6569
}
6670

67-
return id, kind, data, nil
68-
}
71+
records := make([]record, 0, limit)
72+
defer rows.Close()
6973

70-
func remove(ctx context.Context, tx pgx.Tx, id int64) error {
71-
const q = `DELETE FROM pgfsm.command WHERE id = $1`
74+
for rows.Next() {
75+
var r record
76+
if err = rows.Scan(&r.id, &r.kind, &r.data); err != nil {
77+
return nil, err
78+
}
7279

73-
_, err := tx.Exec(ctx, q, id)
74-
return err
80+
records = append(records, r)
81+
}
82+
83+
return records, rows.Err()
7584
}
7685

7786
//go:embed migrate.sql
7887
var migration string
7988

80-
func migrateUp(ctx context.Context, db *pgxpool.Pool, logger *slog.Logger) error {
81-
logger.DebugContext(ctx, "performing migrations")
82-
89+
func migrateUp(ctx context.Context, db *pgxpool.Pool) error {
8390
statements := strings.Split(migration, ";")
8491

8592
for _, statement := range statements {
8693
if strings.TrimSpace(statement) == "" {
8794
continue
8895
}
8996

90-
logger.
91-
With(slog.String("statement", statement)).
92-
DebugContext(ctx, "executing statement")
93-
9497
if _, err := db.Exec(ctx, statement); err != nil {
9598
return err
9699
}

pgfsm.go

Lines changed: 72 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ import (
1818
"context"
1919
"errors"
2020
"fmt"
21-
"log/slog"
21+
"sync"
2222
"time"
2323

2424
"github.com/jackc/pgx/v5"
2525
"github.com/jackc/pgx/v5/pgxpool"
26+
"golang.org/x/sync/errgroup"
2627
)
2728

2829
type (
@@ -79,7 +80,7 @@ func New(ctx context.Context, db *pgxpool.Pool, options ...Option) (*FSM, error)
7980
o(&opts)
8081
}
8182

82-
if err := migrateUp(ctx, db, opts.logger); err != nil {
83+
if err := migrateUp(ctx, db); err != nil {
8384
return nil, err
8485
}
8586

@@ -113,19 +114,11 @@ func (fsm *FSM) Write(ctx context.Context, cmd Command) error {
113114
switch command := cmd.(type) {
114115
case batchCommand:
115116
for _, cmd = range command {
116-
fsm.options.logger.
117-
With(slog.String("command_kind", cmd.Kind())).
118-
InfoContext(ctx, "writing command")
119-
120117
if err := insert(ctx, tx, fsm.options.encoder, cmd); err != nil {
121118
return err
122119
}
123120
}
124121
default:
125-
fsm.options.logger.
126-
With(slog.String("command_kind", cmd.Kind())).
127-
InfoContext(ctx, "writing command")
128-
129122
return insert(ctx, tx, fsm.options.encoder, cmd)
130123
}
131124

@@ -174,76 +167,92 @@ var (
174167
)
175168

176169
func (fsm *FSM) next(ctx context.Context, h Handler) error {
177-
fsm.options.logger.DebugContext(ctx, "checking for new commands")
178-
179170
return transaction(ctx, fsm.db, func(ctx context.Context, tx pgx.Tx) error {
180-
id, kind, data, err := next(ctx, tx)
171+
records, err := next(ctx, tx, fsm.options.concurrency)
181172
switch {
182-
case errors.Is(err, pgx.ErrNoRows):
173+
case len(records) == 0:
183174
return errNoCommands
184175
case err != nil:
185176
return err
186177
}
187178

188-
log := fsm.options.logger.With(
189-
slog.String("command_kind", kind),
190-
slog.Int64("command_id", id),
191-
)
192-
193-
factory, ok := fsm.commandFactories[kind]
194-
switch {
195-
case !ok && fsm.options.skipUnknownCommands:
196-
log.WarnContext(ctx, "skipping unknown command")
197-
return remove(ctx, tx, id)
198-
case !ok && !fsm.options.skipUnknownCommands:
199-
return UnknownCommandError{Kind: kind}
200-
}
179+
return fsm.handleCommands(ctx, tx, records, h)
180+
})
181+
}
201182

202-
cmd := factory()
203-
if err = fsm.options.encoder.Decode(data, cmd); err != nil {
204-
return fmt.Errorf("failed to decode command %d: %w", id, err)
205-
}
183+
func (fsm *FSM) handleCommands(ctx context.Context, tx pgx.Tx, records []record, h Handler) error {
184+
var mux sync.Mutex
185+
commands := make([]Command, 0)
186+
group, gCtx := errgroup.WithContext(ctx)
206187

207-
log.InfoContext(ctx, "handling command")
208-
returned, err := h(ctx, cmd)
209-
if err != nil {
210-
log.ErrorContext(ctx, "error handling command")
211-
return err
212-
}
188+
for _, r := range records {
189+
group.Go(func() error {
190+
cmd, err := fsm.handleCommand(gCtx, r, h)
191+
switch {
192+
case err != nil:
193+
return err
194+
case cmd == nil:
195+
return nil
196+
}
213197

214-
if returned != nil {
215-
switch command := returned.(type) {
198+
mux.Lock()
199+
defer mux.Unlock()
200+
switch command := cmd.(type) {
216201
case batchCommand:
217202
for _, batched := range command {
218-
log.With(slog.String("received_command_kind", batched.Kind())).
219-
InfoContext(ctx, "received additional command")
220-
221-
if err = insert(ctx, tx, fsm.options.encoder, batched); err != nil {
222-
return err
223-
}
203+
commands = append(commands, batched)
224204
}
225205

226206
default:
227-
log.With(slog.String("received_command_kind", returned.Kind())).
228-
InfoContext(ctx, "received additional command")
229-
230-
if err = insert(ctx, tx, fsm.options.encoder, command); err != nil {
231-
return err
232-
}
207+
commands = append(commands, command)
233208
}
209+
210+
return nil
211+
})
212+
}
213+
214+
if err := group.Wait(); err != nil {
215+
return err
216+
}
217+
218+
for _, cmd := range commands {
219+
if err := insert(ctx, tx, fsm.options.encoder, cmd); err != nil {
220+
return err
234221
}
222+
}
235223

236-
return remove(ctx, tx, id)
237-
})
224+
return nil
225+
}
226+
227+
func (fsm *FSM) handleCommand(ctx context.Context, r record, h Handler) (Command, error) {
228+
factory, ok := fsm.commandFactories[r.kind]
229+
switch {
230+
case !ok && fsm.options.skipUnknownCommands:
231+
return nil, nil
232+
case !ok && !fsm.options.skipUnknownCommands:
233+
return nil, UnknownCommandError{Kind: r.kind}
234+
}
235+
236+
cmd := factory()
237+
if err := fsm.options.encoder.Decode(r.data, cmd); err != nil {
238+
return nil, fmt.Errorf("failed to decode command %d: %w", r.id, err)
239+
}
240+
241+
returned, err := h(ctx, cmd)
242+
if err != nil {
243+
return nil, err
244+
}
245+
246+
return returned, nil
238247
}
239248

240249
type (
241250
options struct {
242251
skipUnknownCommands bool
243252
encoder Encoding
244-
logger *slog.Logger
245253
minPollInterval time.Duration
246254
maxPollInterval time.Duration
255+
concurrency uint
247256
}
248257

249258
// The Option type is a function that can modify the behaviour of the FSM.
@@ -254,9 +263,9 @@ func defaultOptions() options {
254263
return options{
255264
skipUnknownCommands: false,
256265
encoder: &JSON{},
257-
logger: slog.New(slog.DiscardHandler),
258266
minPollInterval: time.Millisecond,
259267
maxPollInterval: time.Second * 5,
268+
concurrency: 1,
260269
}
261270
}
262271

@@ -278,6 +287,14 @@ func UseEncoding(e Encoding) Option {
278287
return func(o *options) { o.encoder = e }
279288
}
280289

290+
// SetConcurrency is an Option implementation that changes the number of commands that can be processed at the same time.
291+
// It defaults to 1. Concurrent commands are processed within the same transaction. This means that within a batch of
292+
// commands, should any fail then all commands within the batch will be returned to the database and any subsequent
293+
// commands rolled back.
294+
func SetConcurrency(concurrency uint) Option {
295+
return func(o *options) { o.concurrency = concurrency }
296+
}
297+
281298
// PollInterval is an Option implementation that configures the minimum and maximum frequency at which Command implementations
282299
// will be read from the database. Each time the FSM checks for commands and finds none, it will halve the frequency at
283300
// which it checks up to the maximum value. This is done to prevent excessive load on the database at times where there
@@ -289,12 +306,6 @@ func PollInterval(min, max time.Duration) Option {
289306
}
290307
}
291308

292-
// UseLogger is an Option implementation that modifies the logger used by the FSM. By default, the FSM uses
293-
// slog.DiscardHandler and will not write any logs.
294-
func UseLogger(l *slog.Logger) Option {
295-
return func(o *options) { o.logger = l.WithGroup("pgfsm") }
296-
}
297-
298309
type (
299310
batchCommand []Command
300311
)

pgfsm_test.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ package pgfsm_test
33
import (
44
"context"
55
"errors"
6-
"log/slog"
76
"net/url"
8-
"os"
7+
"sync/atomic"
98
"testing"
109
"time"
1110

@@ -48,20 +47,18 @@ func TestFSM_ReadWrite(t *testing.T) {
4847

4948
db := testDB(t)
5049

51-
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
52-
5350
fsm, err := pgfsm.New(t.Context(), db,
5451
pgfsm.SkipUnknownCommands(),
5552
pgfsm.UseEncoding(&pgfsm.GOB{}),
56-
pgfsm.UseLogger(logger),
5753
pgfsm.PollInterval(time.Millisecond, time.Minute),
54+
pgfsm.SetConcurrency(10),
5855
)
5956
require.NoError(t, err)
6057

6158
var (
62-
handledA bool
63-
handledB bool
64-
handledC bool
59+
handledA atomic.Bool
60+
handledB atomic.Bool
61+
handledC atomic.Bool
6562
)
6663

6764
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
@@ -84,16 +81,16 @@ func TestFSM_ReadWrite(t *testing.T) {
8481
return fsm.Read(ctx, func(ctx context.Context, cmd any) (pgfsm.Command, error) {
8582
switch msg := cmd.(type) {
8683
case *TestCommandA:
87-
handledA = true
84+
handledA.Store(true)
8885
return TestCommandB{Foo: msg.Foo + 1}, nil
8986
case *TestCommandB:
90-
handledB = true
87+
handledB.Store(true)
9188
return pgfsm.Batch(
9289
TestCommandC{Foo: msg.Foo + 1},
9390
TestCommandC{Foo: msg.Foo + 1},
9491
), nil
9592
case *TestCommandC:
96-
handledC = true
93+
handledC.Store(true)
9794
return nil, nil
9895
default:
9996
assert.Fail(t, "should be skipping unknown commands")
@@ -107,9 +104,9 @@ func TestFSM_ReadWrite(t *testing.T) {
107104
require.NoError(t, err)
108105
}
109106

110-
assert.True(t, handledA)
111-
assert.True(t, handledB)
112-
assert.True(t, handledC)
107+
assert.True(t, handledA.Load())
108+
assert.True(t, handledB.Load())
109+
assert.True(t, handledC.Load())
113110
}
114111

115112
func testDB(t *testing.T) *pgxpool.Pool {

0 commit comments

Comments
 (0)