@@ -18,11 +18,12 @@ import (
18
18
"context"
19
19
"errors"
20
20
"fmt"
21
- "log/slog "
21
+ "sync "
22
22
"time"
23
23
24
24
"github.com/jackc/pgx/v5"
25
25
"github.com/jackc/pgx/v5/pgxpool"
26
+ "golang.org/x/sync/errgroup"
26
27
)
27
28
28
29
type (
@@ -79,7 +80,7 @@ func New(ctx context.Context, db *pgxpool.Pool, options ...Option) (*FSM, error)
79
80
o (& opts )
80
81
}
81
82
82
- if err := migrateUp (ctx , db , opts . logger ); err != nil {
83
+ if err := migrateUp (ctx , db ); err != nil {
83
84
return nil , err
84
85
}
85
86
@@ -113,19 +114,11 @@ func (fsm *FSM) Write(ctx context.Context, cmd Command) error {
113
114
switch command := cmd .(type ) {
114
115
case batchCommand :
115
116
for _ , cmd = range command {
116
- fsm .options .logger .
117
- With (slog .String ("command_kind" , cmd .Kind ())).
118
- InfoContext (ctx , "writing command" )
119
-
120
117
if err := insert (ctx , tx , fsm .options .encoder , cmd ); err != nil {
121
118
return err
122
119
}
123
120
}
124
121
default :
125
- fsm .options .logger .
126
- With (slog .String ("command_kind" , cmd .Kind ())).
127
- InfoContext (ctx , "writing command" )
128
-
129
122
return insert (ctx , tx , fsm .options .encoder , cmd )
130
123
}
131
124
@@ -174,76 +167,92 @@ var (
174
167
)
175
168
176
169
func (fsm * FSM ) next (ctx context.Context , h Handler ) error {
177
- fsm .options .logger .DebugContext (ctx , "checking for new commands" )
178
-
179
170
return transaction (ctx , fsm .db , func (ctx context.Context , tx pgx.Tx ) error {
180
- id , kind , data , err := next (ctx , tx )
171
+ records , err := next (ctx , tx , fsm . options . concurrency )
181
172
switch {
182
- case errors . Is ( err , pgx . ErrNoRows ) :
173
+ case len ( records ) == 0 :
183
174
return errNoCommands
184
175
case err != nil :
185
176
return err
186
177
}
187
178
188
- log := fsm .options .logger .With (
189
- slog .String ("command_kind" , kind ),
190
- slog .Int64 ("command_id" , id ),
191
- )
192
-
193
- factory , ok := fsm .commandFactories [kind ]
194
- switch {
195
- case ! ok && fsm .options .skipUnknownCommands :
196
- log .WarnContext (ctx , "skipping unknown command" )
197
- return remove (ctx , tx , id )
198
- case ! ok && ! fsm .options .skipUnknownCommands :
199
- return UnknownCommandError {Kind : kind }
200
- }
179
+ return fsm .handleCommands (ctx , tx , records , h )
180
+ })
181
+ }
201
182
202
- cmd := factory ()
203
- if err = fsm . options . encoder . Decode ( data , cmd ); err != nil {
204
- return fmt . Errorf ( "failed to decode command %d: %w" , id , err )
205
- }
183
+ func ( fsm * FSM ) handleCommands ( ctx context. Context , tx pgx. Tx , records [] record , h Handler ) error {
184
+ var mux sync. Mutex
185
+ commands := make ([] Command , 0 )
186
+ group , gCtx := errgroup . WithContext ( ctx )
206
187
207
- log .InfoContext (ctx , "handling command" )
208
- returned , err := h (ctx , cmd )
209
- if err != nil {
210
- log .ErrorContext (ctx , "error handling command" )
211
- return err
212
- }
188
+ for _ , r := range records {
189
+ group .Go (func () error {
190
+ cmd , err := fsm .handleCommand (gCtx , r , h )
191
+ switch {
192
+ case err != nil :
193
+ return err
194
+ case cmd == nil :
195
+ return nil
196
+ }
213
197
214
- if returned != nil {
215
- switch command := returned .(type ) {
198
+ mux .Lock ()
199
+ defer mux .Unlock ()
200
+ switch command := cmd .(type ) {
216
201
case batchCommand :
217
202
for _ , batched := range command {
218
- log .With (slog .String ("received_command_kind" , batched .Kind ())).
219
- InfoContext (ctx , "received additional command" )
220
-
221
- if err = insert (ctx , tx , fsm .options .encoder , batched ); err != nil {
222
- return err
223
- }
203
+ commands = append (commands , batched )
224
204
}
225
205
226
206
default :
227
- log .With (slog .String ("received_command_kind" , returned .Kind ())).
228
- InfoContext (ctx , "received additional command" )
229
-
230
- if err = insert (ctx , tx , fsm .options .encoder , command ); err != nil {
231
- return err
232
- }
207
+ commands = append (commands , command )
233
208
}
209
+
210
+ return nil
211
+ })
212
+ }
213
+
214
+ if err := group .Wait (); err != nil {
215
+ return err
216
+ }
217
+
218
+ for _ , cmd := range commands {
219
+ if err := insert (ctx , tx , fsm .options .encoder , cmd ); err != nil {
220
+ return err
234
221
}
222
+ }
235
223
236
- return remove (ctx , tx , id )
237
- })
224
+ return nil
225
+ }
226
+
227
+ func (fsm * FSM ) handleCommand (ctx context.Context , r record , h Handler ) (Command , error ) {
228
+ factory , ok := fsm .commandFactories [r .kind ]
229
+ switch {
230
+ case ! ok && fsm .options .skipUnknownCommands :
231
+ return nil , nil
232
+ case ! ok && ! fsm .options .skipUnknownCommands :
233
+ return nil , UnknownCommandError {Kind : r .kind }
234
+ }
235
+
236
+ cmd := factory ()
237
+ if err := fsm .options .encoder .Decode (r .data , cmd ); err != nil {
238
+ return nil , fmt .Errorf ("failed to decode command %d: %w" , r .id , err )
239
+ }
240
+
241
+ returned , err := h (ctx , cmd )
242
+ if err != nil {
243
+ return nil , err
244
+ }
245
+
246
+ return returned , nil
238
247
}
239
248
240
249
type (
241
250
options struct {
242
251
skipUnknownCommands bool
243
252
encoder Encoding
244
- logger * slog.Logger
245
253
minPollInterval time.Duration
246
254
maxPollInterval time.Duration
255
+ concurrency uint
247
256
}
248
257
249
258
// The Option type is a function that can modify the behaviour of the FSM.
@@ -254,9 +263,9 @@ func defaultOptions() options {
254
263
return options {
255
264
skipUnknownCommands : false ,
256
265
encoder : & JSON {},
257
- logger : slog .New (slog .DiscardHandler ),
258
266
minPollInterval : time .Millisecond ,
259
267
maxPollInterval : time .Second * 5 ,
268
+ concurrency : 1 ,
260
269
}
261
270
}
262
271
@@ -278,6 +287,14 @@ func UseEncoding(e Encoding) Option {
278
287
return func (o * options ) { o .encoder = e }
279
288
}
280
289
290
+ // SetConcurrency is an Option implementation that changes the number of commands that can be processed at the same time.
291
+ // It defaults to 1. Concurrent commands are processed within the same transaction. This means that within a batch of
292
+ // commands, should any fail then all commands within the batch will be returned to the database and any subsequent
293
+ // commands rolled back.
294
+ func SetConcurrency (concurrency uint ) Option {
295
+ return func (o * options ) { o .concurrency = concurrency }
296
+ }
297
+
281
298
// PollInterval is an Option implementation that configures the minimum and maximum frequency at which Command implementations
282
299
// will be read from the database. Each time the FSM checks for commands and finds none, it will half the frequency at
283
300
// which it checks up to the maximum value. This is done to prevent excessive load on the database at times where there
@@ -289,12 +306,6 @@ func PollInterval(min, max time.Duration) Option {
289
306
}
290
307
}
291
308
292
- // UseLogger is an Option implementation that modifies the logger used by the FSM. By default, the FSM uses
293
- // slog.DiscardHandler and will not write any logs.
294
- func UseLogger (l * slog.Logger ) Option {
295
- return func (o * options ) { o .logger = l .WithGroup ("pgfsm" ) }
296
- }
297
-
298
309
type (
299
310
batchCommand []Command
300
311
)
0 commit comments