diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index fd206590d9a..54653d52b3e 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -60,6 +60,9 @@ type Route struct { // Query specifies the query to be executed. Query string + // QueryStatement is the parsed AST of Query + QueryStatement sqlparser.Statement + // FieldQuery specifies the query to be executed for a GetFieldInfo request. FieldQuery string @@ -512,6 +515,18 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs return } + // Remove FOR UPDATE locks for warming reads if present + warmingQueries := queries + if modifiedQuery, ok := removeForUpdateLocks(route.QueryStatement); ok { + warmingQueries = make([]*querypb.BoundQuery, len(queries)) + for i, query := range queries { + warmingQueries[i] = &querypb.BoundQuery{ + Sql: modifiedQuery, + BindVariables: query.BindVariables, + } + } + } + replicaVCursor := vcursor.CloneForReplicaWarming(ctx) warmingReadsChannel := vcursor.GetWarmingReadsChannel() @@ -527,7 +542,7 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs return } - _, errs := replicaVCursor.ExecuteMultiShard(ctx, route, rss, queries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID) + _, errs := replicaVCursor.ExecuteMultiShard(ctx, route, rss, warmingQueries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID) if len(errs) > 0 { log.Warningf("Failed to execute warming replica read: %v", errs) } else { @@ -538,3 +553,23 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs log.Warning("Failed to execute warming replica read as pool is full") } } + +func removeForUpdateLocks(stmt sqlparser.Statement) (string, bool) { + sel, ok := stmt.(*sqlparser.Select) + if !ok { + return "", false + } + + // Check if this is a FOR UPDATE query + if sel.Lock != sqlparser.ForUpdateLock && + sel.Lock != sqlparser.ForUpdateLockNoWait && + sel.Lock != sqlparser.ForUpdateLockSkipLocked { + return "", false + } + + // Remove the lock clause + sel.Lock = sqlparser.NoLock + + // Convert back to SQL string + return sqlparser.String(sel), true +} diff --git a/go/vt/vtgate/engine/route_warming_reads_test.go b/go/vt/vtgate/engine/route_warming_reads_test.go new file mode 100644 index 00000000000..f0dadeb4c13 --- /dev/null +++ b/go/vt/vtgate/engine/route_warming_reads_test.go @@ -0,0 +1,142 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" + + querypb "vitess.io/vitess/go/vt/proto/query" +) + +type warmingReadsVCursor struct { + *loggingVCursor + warmingReadsPercent int + warmingReadsChannel chan bool + warmingReadsExecuteFunc func(context.Context, Primitive, []*srvtopo.ResolvedShard, []*querypb.BoundQuery, bool, bool) +} + +func (vc *warmingReadsVCursor) GetWarmingReadsPercent() int { + return vc.warmingReadsPercent +} + +func (vc *warmingReadsVCursor) GetWarmingReadsChannel() chan bool { + return vc.warmingReadsChannel +} + +func (vc *warmingReadsVCursor) CloneForReplicaWarming(ctx context.Context) VCursor { + clone := &warmingReadsVCursor{ + loggingVCursor: vc.loggingVCursor, + warmingReadsPercent: vc.warmingReadsPercent, + warmingReadsChannel: vc.warmingReadsChannel, + warmingReadsExecuteFunc: vc.warmingReadsExecuteFunc, + } + clone.onExecuteMultiShardFn = vc.warmingReadsExecuteFunc + return clone +} + +func TestWarmingReadsSkipsForUpdate(t *testing.T) { + vindex, _ := vindexes.CreateVindex("hash", "", nil) + testCases := []struct { + name string + query string + expectedWarmingQuery string + }{ + { + name: "SELECT FOR UPDATE", + query: "SELECT * FROM users WHERE id = 1 FOR UPDATE", + expectedWarmingQuery: "select * from users where id = 1", + }, + { + name: "SELECT FOR UPDATE mixed case", + query: "SELECT * FROM users WHERE id = 1 FoR UpDaTe", + expectedWarmingQuery: "select * from users where id = 1", + }, + { + name: "SELECT FOR UPDATE with extra spaces", + query: "SELECT * FROM users WHERE id = 1 FOR UPDATE", + expectedWarmingQuery: "select * from users where id = 1", + }, + { + name: "SELECT FOR UPDATE with comment", + query: "SELECT * FROM users WHERE id = 1 FOR /* comment */ UPDATE", + expectedWarmingQuery: "select * from users where id = 1", + }, + { + name: "Regular SELECT", + query: "SELECT * FROM users WHERE id = 1", + expectedWarmingQuery: "SELECT * FROM users WHERE id = 1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + route := NewRoute( + EqualUnique, + &vindexes.Keyspace{ + Name: "ks", + Sharded: true, + }, + tc.query, + "dummy_select_field", + ) + // Parse and set QueryStatement to match how Routes are created in production + parser, _ := sqlparser.NewTestParser().Parse(tc.query) + route.QueryStatement = parser + route.Vindex = vindex.(vindexes.SingleColumn) + route.Values = []evalengine.Expr{ + evalengine.NewLiteralInt(1), + } + + var warmingReadExecuted atomic.Bool + var capturedQuery string + vc := &warmingReadsVCursor{ + loggingVCursor: &loggingVCursor{ + shards: []string{"-20", "20-"}, + results: []*sqltypes.Result{defaultSelectResult}, + }, + warmingReadsPercent: 100, + warmingReadsChannel: make(chan bool, 1), + } + vc.warmingReadsExecuteFunc = func(ctx context.Context, primitive Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) { + if len(queries) > 0 { + capturedQuery = queries[0].Sql + } + warmingReadExecuted.Store(true) + } + + _, err := route.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return warmingReadExecuted.Load() + }, time.Second, 10*time.Millisecond, "warming read should be executed") + + require.Equal(t, tc.expectedWarmingQuery, capturedQuery, "warming read query should match expected") + }) + } +} diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index 167bfa6e191..e2153e65a6f 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -29,6 +29,7 @@ import ( func WireupRoute(ctx *plancontext.PlanningContext, eroute *engine.Route, sel sqlparser.SelectStatement) (engine.Primitive, error) { // prepare the queries we will pass down eroute.Query = sqlparser.String(sel) + eroute.QueryStatement = sel buffer := sqlparser.NewTrackedBuffer(sqlparser.FormatImpossibleQuery) node := buffer.WriteNode(sel) eroute.FieldQuery = node.ParsedQuery().Query