Skip to content

Commit 1e2c0dd

Browse files
authoredJul 8, 2023
feat: support for enforcing the use of the leader db (#18)
1 parent 0065098 commit 1e2c0dd

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed
 

‎multiple.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ func (m *multipleSqlConn) containSelect(query string) bool {
167167
return false
168168
}
169169

170-
func (m *multipleSqlConn) getQueryDB(query string) queryDB {
170+
func (m *multipleSqlConn) getQueryDB(ctx context.Context, query string) queryDB {
171+
if forceLeaderFromContext(ctx) {
172+
return queryDB{conn: m.leader}
173+
}
174+
171175
if !m.enableFollower {
172176
return queryDB{conn: m.leader}
173177
}
@@ -242,7 +246,7 @@ func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, db int) (co
242246
}
243247

244248
func (m *multipleSqlConn) query(ctx context.Context, query string, do func(ctx context.Context, conn sqlx.SqlConn) error) error {
245-
db := m.getQueryDB(query)
249+
db := m.getQueryDB(ctx, query)
246250
var span oteltrace.Span
247251
if db.follower {
248252
ctx, span = m.startSpanWithFollower(ctx, db.followerDB)
@@ -297,3 +301,15 @@ func WithAccept(accept func(err error) bool) SqlOption {
297301
conn.accept = accept
298302
}
299303
}
304+
305+
type forceLeaderKey struct{}
306+
307+
func ForceLeaderContext(ctx context.Context) context.Context {
308+
return context.WithValue(ctx, forceLeaderKey{}, struct{}{})
309+
}
310+
311+
func forceLeaderFromContext(ctx context.Context) bool {
312+
value := ctx.Value(forceLeaderKey{})
313+
_, ok := value.(struct{})
314+
return ok
315+
}

‎multiple_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sqlx
22

33
import (
4+
"context"
45
"database/sql/driver"
56
"testing"
67
"time"
@@ -52,3 +53,10 @@ func TestNewMultipleSqlConn(t *testing.T) {
5253
assert.NoError(t, err)
5354
assert.Equal(t, int64(1), rowsAffected)
5455
}
56+
57+
func TestForceLeaderContext(t *testing.T) {
58+
ctx := ForceLeaderContext(context.Background())
59+
assert.True(t, forceLeaderFromContext(ctx))
60+
61+
assert.False(t, forceLeaderFromContext(context.Background()))
62+
}

0 commit comments

Comments
 (0)
Please sign in to comment.