Skip to content

Commit ffcd402

Browse files
committed
add tx retry test
1 parent 02fa35e commit ffcd402

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

tx_test.go

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package pgsql_test
22

33
import (
4+
"context"
45
"database/sql"
6+
"database/sql/driver"
7+
"errors"
58
"fmt"
69
"log"
710
"math/rand"
811
"sync"
912
"testing"
13+
"time"
1014

1115
"github.com/acoshift/pgsql"
1216
)
@@ -148,3 +152,155 @@ func TestTx(t *testing.T) {
148152
t.Fatalf("expected sum all value to be 0; got %d", result)
149153
}
150154
}
155+
156+
func TestTxRetryWithBackoff(t *testing.T) {
157+
t.Parallel()
158+
159+
t.Run("Backoff when serialization failure occurs", func(t *testing.T) {
160+
t.Parallel()
161+
162+
attemptCount := 0
163+
opts := &pgsql.TxOptions{
164+
MaxAttempts: 3,
165+
BackoffDelayFunc: func(attempt int) time.Duration {
166+
attemptCount++
167+
return 1
168+
},
169+
}
170+
171+
pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error {
172+
return &mockSerializationFailureError{}
173+
})
174+
175+
if attemptCount != opts.MaxAttempts-1 {
176+
t.Fatalf("expected BackoffDelayFunc to be called %d times, got %d", opts.MaxAttempts, attemptCount)
177+
}
178+
})
179+
180+
t.Run("Successful After Multiple Failures", func(t *testing.T) {
181+
t.Parallel()
182+
183+
failCount := 0
184+
maxFailures := 3
185+
opts := &pgsql.TxOptions{
186+
MaxAttempts: maxFailures + 1,
187+
BackoffDelayFunc: func(attempt int) time.Duration {
188+
return 1
189+
},
190+
}
191+
192+
err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(tx *sql.Tx) error {
193+
if failCount < maxFailures {
194+
failCount++
195+
return &mockSerializationFailureError{}
196+
}
197+
return nil
198+
})
199+
if err != nil {
200+
t.Fatalf("expected success after failures, got error: %v", err)
201+
}
202+
if failCount != maxFailures {
203+
t.Fatalf("expected %d failures before success, got %d", maxFailures, failCount)
204+
}
205+
})
206+
207+
t.Run("Context Cancellation", func(t *testing.T) {
208+
t.Parallel()
209+
210+
opts := &pgsql.TxOptions{
211+
MaxAttempts: 3,
212+
BackoffDelayFunc: func(attempt int) time.Duration {
213+
return 1
214+
},
215+
}
216+
217+
ctx, cancel := context.WithCancel(context.Background())
218+
cancel() // Cancel the context immediately
219+
220+
err := pgsql.RunInTxContext(ctx, sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error {
221+
return &mockSerializationFailureError{}
222+
})
223+
if !errors.Is(err, context.Canceled) {
224+
t.Fatalf("expected context.Canceled error, got %v", err)
225+
}
226+
})
227+
228+
t.Run("Max Attempts Reached", func(t *testing.T) {
229+
t.Parallel()
230+
231+
attemptCount := 0
232+
opts := &pgsql.TxOptions{
233+
MaxAttempts: 3,
234+
BackoffDelayFunc: func(attempt int) time.Duration {
235+
return 1
236+
},
237+
}
238+
239+
err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error {
240+
attemptCount++
241+
return &mockSerializationFailureError{}
242+
})
243+
if errors.As(err, &mockSerializationFailureError{}) {
244+
t.Fatal("expected an error when max attempts reached")
245+
}
246+
if attemptCount != opts.MaxAttempts {
247+
t.Fatalf("expected %d attempts, got %d", opts.MaxAttempts, attemptCount)
248+
}
249+
})
250+
}
251+
252+
type fakeConnector struct {
253+
driver.Connector
254+
}
255+
256+
func (c *fakeConnector) Connect(ctx context.Context) (driver.Conn, error) {
257+
return &fakeConn{}, nil
258+
}
259+
260+
func (c *fakeConnector) Driver() driver.Driver {
261+
panic("not implemented")
262+
}
263+
264+
type fakeConn struct {
265+
driver.Conn
266+
}
267+
268+
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
269+
return nil, fmt.Errorf("not implemented")
270+
}
271+
272+
func (c *fakeConn) Close() error {
273+
return nil
274+
}
275+
276+
func (c *fakeConn) Begin() (driver.Tx, error) {
277+
return &fakeTx{}, nil
278+
}
279+
280+
var _ driver.ConnBeginTx = (*fakeConn)(nil)
281+
282+
func (c *fakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
283+
return &fakeTx{}, nil
284+
}
285+
286+
type fakeTx struct {
287+
driver.Tx
288+
}
289+
290+
func (tx *fakeTx) Commit() error {
291+
return nil
292+
}
293+
294+
func (tx *fakeTx) Rollback() error {
295+
return nil
296+
}
297+
298+
type mockSerializationFailureError struct{}
299+
300+
func (e mockSerializationFailureError) Error() string {
301+
return "mock serialization failure error"
302+
}
303+
304+
func (e mockSerializationFailureError) SQLState() string {
305+
return "40001" // SQLSTATE code for serialization failure
306+
}

0 commit comments

Comments
 (0)