@@ -167,7 +167,11 @@ func (m *multipleSqlConn) containSelect(query string) bool {
167
167
return false
168
168
}
169
169
170
- func (m * multipleSqlConn ) getQueryDB (query string ) queryDB {
170
+ func (m * multipleSqlConn ) getQueryDB (ctx context.Context , query string ) queryDB {
171
+ if forceLeaderFromContext (ctx ) {
172
+ return queryDB {conn : m .leader }
173
+ }
174
+
171
175
if ! m .enableFollower {
172
176
return queryDB {conn : m .leader }
173
177
}
@@ -242,7 +246,7 @@ func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, db int) (co
242
246
}
243
247
244
248
func (m * multipleSqlConn ) query (ctx context.Context , query string , do func (ctx context.Context , conn sqlx.SqlConn ) error ) error {
245
- db := m .getQueryDB (query )
249
+ db := m .getQueryDB (ctx , query )
246
250
var span oteltrace.Span
247
251
if db .follower {
248
252
ctx , span = m .startSpanWithFollower (ctx , db .followerDB )
@@ -297,3 +301,15 @@ func WithAccept(accept func(err error) bool) SqlOption {
297
301
conn .accept = accept
298
302
}
299
303
}
304
+
305
+ type forceLeaderKey struct {}
306
+
307
+ func ForceLeaderContext (ctx context.Context ) context.Context {
308
+ return context .WithValue (ctx , forceLeaderKey {}, struct {}{})
309
+ }
310
+
311
+ func forceLeaderFromContext (ctx context.Context ) bool {
312
+ value := ctx .Value (forceLeaderKey {})
313
+ _ , ok := value .(struct {})
314
+ return ok
315
+ }
0 commit comments