|
1 | 1 | package pgsql_test
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
4 | 5 | "database/sql"
|
| 6 | + "database/sql/driver" |
| 7 | + "errors" |
5 | 8 | "fmt"
|
6 | 9 | "log"
|
7 | 10 | "math/rand"
|
8 | 11 | "sync"
|
9 | 12 | "testing"
|
| 13 | + "time" |
10 | 14 |
|
11 | 15 | "github.com/acoshift/pgsql"
|
12 | 16 | )
|
@@ -148,3 +152,155 @@ func TestTx(t *testing.T) {
|
148 | 152 | t.Fatalf("expected sum all value to be 0; got %d", result)
|
149 | 153 | }
|
150 | 154 | }
|
| 155 | + |
| 156 | +func TestTxRetryWithBackoff(t *testing.T) { |
| 157 | + t.Parallel() |
| 158 | + |
| 159 | + t.Run("Backoff when serialization failure occurs", func(t *testing.T) { |
| 160 | + t.Parallel() |
| 161 | + |
| 162 | + attemptCount := 0 |
| 163 | + opts := &pgsql.TxOptions{ |
| 164 | + MaxAttempts: 3, |
| 165 | + BackoffDelayFunc: func(attempt int) time.Duration { |
| 166 | + attemptCount++ |
| 167 | + return 1 |
| 168 | + }, |
| 169 | + } |
| 170 | + |
| 171 | + pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { |
| 172 | + return &mockSerializationFailureError{} |
| 173 | + }) |
| 174 | + |
| 175 | + if attemptCount != opts.MaxAttempts-1 { |
| 176 | + t.Fatalf("expected BackoffDelayFunc to be called %d times, got %d", opts.MaxAttempts, attemptCount) |
| 177 | + } |
| 178 | + }) |
| 179 | + |
| 180 | + t.Run("Successful After Multiple Failures", func(t *testing.T) { |
| 181 | + t.Parallel() |
| 182 | + |
| 183 | + failCount := 0 |
| 184 | + maxFailures := 3 |
| 185 | + opts := &pgsql.TxOptions{ |
| 186 | + MaxAttempts: maxFailures + 1, |
| 187 | + BackoffDelayFunc: func(attempt int) time.Duration { |
| 188 | + return 1 |
| 189 | + }, |
| 190 | + } |
| 191 | + |
| 192 | + err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(tx *sql.Tx) error { |
| 193 | + if failCount < maxFailures { |
| 194 | + failCount++ |
| 195 | + return &mockSerializationFailureError{} |
| 196 | + } |
| 197 | + return nil |
| 198 | + }) |
| 199 | + if err != nil { |
| 200 | + t.Fatalf("expected success after failures, got error: %v", err) |
| 201 | + } |
| 202 | + if failCount != maxFailures { |
| 203 | + t.Fatalf("expected %d failures before success, got %d", maxFailures, failCount) |
| 204 | + } |
| 205 | + }) |
| 206 | + |
| 207 | + t.Run("Context Cancellation", func(t *testing.T) { |
| 208 | + t.Parallel() |
| 209 | + |
| 210 | + opts := &pgsql.TxOptions{ |
| 211 | + MaxAttempts: 3, |
| 212 | + BackoffDelayFunc: func(attempt int) time.Duration { |
| 213 | + return 1 |
| 214 | + }, |
| 215 | + } |
| 216 | + |
| 217 | + ctx, cancel := context.WithCancel(context.Background()) |
| 218 | + cancel() // Cancel the context immediately |
| 219 | + |
| 220 | + err := pgsql.RunInTxContext(ctx, sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { |
| 221 | + return &mockSerializationFailureError{} |
| 222 | + }) |
| 223 | + if !errors.Is(err, context.Canceled) { |
| 224 | + t.Fatalf("expected context.Canceled error, got %v", err) |
| 225 | + } |
| 226 | + }) |
| 227 | + |
| 228 | + t.Run("Max Attempts Reached", func(t *testing.T) { |
| 229 | + t.Parallel() |
| 230 | + |
| 231 | + attemptCount := 0 |
| 232 | + opts := &pgsql.TxOptions{ |
| 233 | + MaxAttempts: 3, |
| 234 | + BackoffDelayFunc: func(attempt int) time.Duration { |
| 235 | + return 1 |
| 236 | + }, |
| 237 | + } |
| 238 | + |
| 239 | + err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { |
| 240 | + attemptCount++ |
| 241 | + return &mockSerializationFailureError{} |
| 242 | + }) |
| 243 | + if errors.As(err, &mockSerializationFailureError{}) { |
| 244 | + t.Fatal("expected an error when max attempts reached") |
| 245 | + } |
| 246 | + if attemptCount != opts.MaxAttempts { |
| 247 | + t.Fatalf("expected %d attempts, got %d", opts.MaxAttempts, attemptCount) |
| 248 | + } |
| 249 | + }) |
| 250 | +} |
| 251 | + |
| 252 | +type fakeConnector struct { |
| 253 | + driver.Connector |
| 254 | +} |
| 255 | + |
| 256 | +func (c *fakeConnector) Connect(ctx context.Context) (driver.Conn, error) { |
| 257 | + return &fakeConn{}, nil |
| 258 | +} |
| 259 | + |
| 260 | +func (c *fakeConnector) Driver() driver.Driver { |
| 261 | + panic("not implemented") |
| 262 | +} |
| 263 | + |
| 264 | +type fakeConn struct { |
| 265 | + driver.Conn |
| 266 | +} |
| 267 | + |
| 268 | +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { |
| 269 | + return nil, fmt.Errorf("not implemented") |
| 270 | +} |
| 271 | + |
| 272 | +func (c *fakeConn) Close() error { |
| 273 | + return nil |
| 274 | +} |
| 275 | + |
| 276 | +func (c *fakeConn) Begin() (driver.Tx, error) { |
| 277 | + return &fakeTx{}, nil |
| 278 | +} |
| 279 | + |
| 280 | +var _ driver.ConnBeginTx = (*fakeConn)(nil) |
| 281 | + |
| 282 | +func (c *fakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { |
| 283 | + return &fakeTx{}, nil |
| 284 | +} |
| 285 | + |
| 286 | +type fakeTx struct { |
| 287 | + driver.Tx |
| 288 | +} |
| 289 | + |
| 290 | +func (tx *fakeTx) Commit() error { |
| 291 | + return nil |
| 292 | +} |
| 293 | + |
| 294 | +func (tx *fakeTx) Rollback() error { |
| 295 | + return nil |
| 296 | +} |
| 297 | + |
| 298 | +type mockSerializationFailureError struct{} |
| 299 | + |
| 300 | +func (e mockSerializationFailureError) Error() string { |
| 301 | + return "mock serialization failure error" |
| 302 | +} |
| 303 | + |
| 304 | +func (e mockSerializationFailureError) SQLState() string { |
| 305 | + return "40001" // SQLSTATE code for serialization failure |
| 306 | +} |
0 commit comments