Skip to content

Commit

Permalink
session.DB apply context to fallback db
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Feb 3, 2025
1 parent 27d1cbd commit 0c2bd2f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
7 changes: 4 additions & 3 deletions util/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,13 @@ func (s Gorm) nestedTransaction(tx *gorm.DB, f func(context.Context) error) erro
return err
}

// DB returns the Gorm instance stored in the given context. Returns the given fallback
// if no Gorm DB could be found in the context.
// DB returns the Gorm instance stored in the given context.
// If no Gorm DB could be found in the context, calls `fallback.WithContext` and
// return the result.
func DB(ctx context.Context, fallback *gorm.DB) *gorm.DB {
db := ctx.Value(dbKey{})
if db == nil {
return fallback
return fallback.WithContext(ctx)
}
return db.(*gorm.DB)
}
37 changes: 24 additions & 13 deletions util/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type testDialector struct {

savepoint string
rolledbackTo string

id string
}

func (d *testDialector) SavePoint(_ *gorm.DB, name string) error {
Expand Down Expand Up @@ -384,36 +386,45 @@ func TestGormSession(t *testing.T) {
})

t.Run("DB", func(t *testing.T) {
db, err := database.NewFromDialector(cfg, nil, &testDialector{})
db, err := database.NewFromDialector(cfg, nil, &testDialector{id: "in_context"})
require.NoError(t, err)
fallback, err := database.NewFromDialector(cfg, nil, &testDialector{id: "fallback"})
require.NoError(t, err)
fallback := &gorm.DB{}

valueCtx := context.WithValue(context.Background(), testKey{}, "testvalue")
cases := []struct {
ctx context.Context
expect *gorm.DB
expect func(t *testing.T, result *gorm.DB)
desc string
}{
{
desc: "missing_from_context",
ctx: context.Background(),
expect: fallback,
desc: "missing_from_context",
ctx: context.Background(),
expect: func(t *testing.T, result *gorm.DB) {
assert.Equal(t, fallback.Dialector.(*testDialector).id, result.Dialector.(*testDialector).id)
},
},
{
desc: "fallback",
ctx: context.Background(),
expect: fallback,
desc: "fallback",
ctx: valueCtx,
expect: func(t *testing.T, result *gorm.DB) {
assert.Equal(t, fallback.Dialector.(*testDialector).id, result.Dialector.(*testDialector).id)
assert.Equal(t, "testvalue", result.Statement.Context.Value(testKey{}))
},
},
{
desc: "found",
ctx: context.WithValue(context.Background(), dbKey{}, db),
expect: db,
desc: "found",
ctx: context.WithValue(context.Background(), dbKey{}, db),
expect: func(t *testing.T, result *gorm.DB) {
assert.Same(t, db, result)
},
},
}

for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
db := DB(c.ctx, fallback)
assert.Equal(t, c.expect, db)
c.expect(t, db)
})
}
})
Expand Down

0 comments on commit 0c2bd2f

Please sign in to comment.