diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index 0c7ecc97bc3..8fd06adb219 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -23,6 +23,7 @@ import ( "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/predicates" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) @@ -103,6 +104,11 @@ func (u *Union) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Ex offsets[ae.ColumnName()] = i } + if jp, ok := expr.(*predicates.JoinPredicate); ok { + expr = jp.Current() + ctx.PredTracker.Skip(jp.ID) + } + needsFilter, exprPerSource := u.predicatePerSource(expr, offsets) if needsFilter { return newFilter(u, expr) @@ -118,6 +124,7 @@ func (u *Union) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Ex func (u *Union) predicatePerSource(expr sqlparser.Expr, offsets map[string]int) (bool, []sqlparser.Expr) { needsFilter := false exprPerSource := make([]sqlparser.Expr, len(u.Sources)) + for i := range u.Sources { predicate := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { col, ok := cursor.Node().(*sqlparser.ColName) @@ -137,11 +144,13 @@ func (u *Union) predicatePerSource(expr sqlparser.Expr, offsets map[string]int) if !ok { panic(vterrors.VT09015()) } + cursor.Replace(ae.Expr) }, nil).(sqlparser.Expr) exprPerSource[i] = predicate } + return needsFilter, exprPerSource } diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 0cf8defd671..7555662b078 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -2086,5 +2086,79 @@ "user.user" ] } + }, + { + "comment": "join with derived table containing UNION with aliased columns across different keyspaces", + "query": "select u.id from user as u join (select id as pid from user where id = 1 union select id as pid from unsharded where id = 1) as i on u.id = i.pid", + "plan": { + "Type": "Join", + "QueryType": "SELECT", + "Original": "select u.id from user as u join (select id as pid from user where id = 1 union select id as pid from unsharded where id = 1) as i on u.id = i.pid", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0", + "JoinVars": { + "u_id": 0 + }, + "TableName": "`user`_`user`_unsharded", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id from `user` as u where 1 != 1", + "Query": "select u.id from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Distinct", + "Collations": [ + "(0:1)" + ], + "Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select dt.c0 as pid, weight_string(dt.c0) from (select id as pid from `user` where 1 != 1) as dt(c0) where 1 != 1", + "Query": "select dt.c0 as pid, weight_string(dt.c0) from (select distinct id as pid from `user` where id = 1 and id = :u_id) as dt(c0)", + "Table": "`user`", + "Values": [ + ":u_id" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select dt.c0 as pid, weight_string(dt.c0) from (select id as pid from unsharded where 1 != 1) as dt(c0) where 1 != 1", + "Query": "select dt.c0 as pid, weight_string(dt.c0) from (select distinct id as pid from unsharded where id = 1 and id = :u_id) as dt(c0)", + "Table": "unsharded" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "main.unsharded", + "user.user" + ] + } } ]