Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion go/vt/vtgate/engine/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type Route struct {
// Query specifies the query to be executed.
Query string

// QueryStatement is the parsed AST of Query
QueryStatement sqlparser.Statement

// FieldQuery specifies the query to be executed for a GetFieldInfo request.
FieldQuery string

Expand Down Expand Up @@ -512,6 +515,18 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs
return
}

// Remove FOR UPDATE locks for warming reads if present
warmingQueries := queries
if modifiedQuery, ok := removeForUpdateLocks(route.QueryStatement); ok {
warmingQueries = make([]*querypb.BoundQuery, len(queries))
for i, query := range queries {
warmingQueries[i] = &querypb.BoundQuery{
Sql: modifiedQuery,
BindVariables: query.BindVariables,
}
}
}

replicaVCursor := vcursor.CloneForReplicaWarming(ctx)
warmingReadsChannel := vcursor.GetWarmingReadsChannel()

Expand All @@ -527,7 +542,7 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs
return
}

_, errs := replicaVCursor.ExecuteMultiShard(ctx, route, rss, queries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID)
_, errs := replicaVCursor.ExecuteMultiShard(ctx, route, rss, warmingQueries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID)
if len(errs) > 0 {
log.Warningf("Failed to execute warming replica read: %v", errs)
} else {
Expand All @@ -538,3 +553,23 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs
log.Warning("Failed to execute warming replica read as pool is full")
}
}

func removeForUpdateLocks(stmt sqlparser.Statement) (string, bool) {
sel, ok := stmt.(*sqlparser.Select)
if !ok {
return "", false
}

// Check if this is a FOR UPDATE query
if sel.Lock != sqlparser.ForUpdateLock &&
sel.Lock != sqlparser.ForUpdateLockNoWait &&
sel.Lock != sqlparser.ForUpdateLockSkipLocked {
return "", false
}

// Remove the lock clause
sel.Lock = sqlparser.NoLock

// Convert back to SQL string
return sqlparser.String(sel), true
}
142 changes: 142 additions & 0 deletions go/vt/vtgate/engine/route_warming_reads_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
Copyright 2024 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/srvtopo"
"vitess.io/vitess/go/vt/vtgate/evalengine"
"vitess.io/vitess/go/vt/vtgate/vindexes"

querypb "vitess.io/vitess/go/vt/proto/query"
)

type warmingReadsVCursor struct {
*loggingVCursor
warmingReadsPercent int
warmingReadsChannel chan bool
warmingReadsExecuteFunc func(context.Context, Primitive, []*srvtopo.ResolvedShard, []*querypb.BoundQuery, bool, bool)
}

func (vc *warmingReadsVCursor) GetWarmingReadsPercent() int {
return vc.warmingReadsPercent
}

func (vc *warmingReadsVCursor) GetWarmingReadsChannel() chan bool {
return vc.warmingReadsChannel
}

func (vc *warmingReadsVCursor) CloneForReplicaWarming(ctx context.Context) VCursor {
clone := &warmingReadsVCursor{
loggingVCursor: vc.loggingVCursor,
warmingReadsPercent: vc.warmingReadsPercent,
warmingReadsChannel: vc.warmingReadsChannel,
warmingReadsExecuteFunc: vc.warmingReadsExecuteFunc,
}
clone.onExecuteMultiShardFn = vc.warmingReadsExecuteFunc
return clone
}

func TestWarmingReadsSkipsForUpdate(t *testing.T) {
vindex, _ := vindexes.CreateVindex("hash", "", nil)
testCases := []struct {
name string
query string
expectedWarmingQuery string
}{
{
name: "SELECT FOR UPDATE",
query: "SELECT * FROM users WHERE id = 1 FOR UPDATE",
expectedWarmingQuery: "select * from users where id = 1",
},
{
name: "SELECT FOR UPDATE mixed case",
query: "SELECT * FROM users WHERE id = 1 FoR UpDaTe",
expectedWarmingQuery: "select * from users where id = 1",
},
{
name: "SELECT FOR UPDATE with extra spaces",
query: "SELECT * FROM users WHERE id = 1 FOR UPDATE",
expectedWarmingQuery: "select * from users where id = 1",
},
{
name: "SELECT FOR UPDATE with comment",
query: "SELECT * FROM users WHERE id = 1 FOR /* comment */ UPDATE",
expectedWarmingQuery: "select * from users where id = 1",
},
{
name: "Regular SELECT",
query: "SELECT * FROM users WHERE id = 1",
expectedWarmingQuery: "SELECT * FROM users WHERE id = 1",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
route := NewRoute(
EqualUnique,
&vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
tc.query,
"dummy_select_field",
)
// Parse and set QueryStatement to match how Routes are created in production
parser, _ := sqlparser.NewTestParser().Parse(tc.query)
route.QueryStatement = parser
route.Vindex = vindex.(vindexes.SingleColumn)
route.Values = []evalengine.Expr{
evalengine.NewLiteralInt(1),
}

var warmingReadExecuted atomic.Bool
var capturedQuery string
vc := &warmingReadsVCursor{
loggingVCursor: &loggingVCursor{
shards: []string{"-20", "20-"},
results: []*sqltypes.Result{defaultSelectResult},
},
warmingReadsPercent: 100,
warmingReadsChannel: make(chan bool, 1),
}
vc.warmingReadsExecuteFunc = func(ctx context.Context, primitive Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) {
if len(queries) > 0 {
capturedQuery = queries[0].Sql
}
warmingReadExecuted.Store(true)
}

_, err := route.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)

require.Eventually(t, func() bool {
return warmingReadExecuted.Load()
}, time.Second, 10*time.Millisecond, "warming read should be executed")

require.Equal(t, tc.expectedWarmingQuery, capturedQuery, "warming read query should match expected")
})
}
}
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
func WireupRoute(ctx *plancontext.PlanningContext, eroute *engine.Route, sel sqlparser.SelectStatement) (engine.Primitive, error) {
// prepare the queries we will pass down
eroute.Query = sqlparser.String(sel)
eroute.QueryStatement = sel
buffer := sqlparser.NewTrackedBuffer(sqlparser.FormatImpossibleQuery)
node := buffer.WriteNode(sel)
eroute.FieldQuery = node.ParsedQuery().Query
Expand Down
Loading