Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,15 @@ func TestDistinctAggregation(t *testing.T) {
expectedErr string
minVersion int
}{{
query: `SELECT COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct shardkey) (errno 1235) (sqlstate 42000)",
// Multiple distinct aggregations with different columns - now supported via hash-based distinct
query: `SELECT COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
}, {
query: `SELECT a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`,
}, {
query: `SELECT a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`,
}, {
query: `SELECT count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`,
expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct b.t1_id) (errno 1235) (sqlstate 42000)",
// Multiple distinct aggregations in join - now supported via hash-based distinct
query: `SELECT count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`,
}, {
query: `SELECT a.value, SUM(DISTINCT b.t1_id), min(DISTINCT a.t1_id) FROM t1 a, t1 b group by a.value`,
}, {
Expand Down
144 changes: 118 additions & 26 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
"vitess.io/vitess/go/vt/vthash"
)

// AggregateParams specify the parameters for each aggregation.
Expand Down Expand Up @@ -56,6 +57,11 @@ type AggregateParams struct {
OrigOpcode opcode.AggregateOpcode

CollationEnv *collations.Environment

// UseHashDistinct indicates this aggregation should use hash-based
// distinct tracking instead of sort-based. This is used when multiple
// distinct aggregations with different expressions exist in a query.
UseHashDistinct bool
}

// NewAggregateParam creates a new aggregate param
Expand Down Expand Up @@ -128,6 +134,14 @@ type aggregator interface {
reset()
}

// distinctTracker is an interface for tracking distinct values.
// It can be implemented by sort-based (aggregatorDistinct) or
// hash-based (aggregatorDistinctHash) trackers.
type distinctTracker interface {
	// shouldReturn inspects row and reports whether the calling aggregator
	// should return without accumulating it. Callers only consult the
	// returned error when the boolean is true.
	shouldReturn(row []sqltypes.Value) (bool, error)
	// reset clears the tracker's state; aggregators call it from their own
	// reset methods.
	reset()
}

type aggregatorDistinct struct {
column int
last sqltypes.Value
Expand Down Expand Up @@ -160,18 +174,69 @@ func (a *aggregatorDistinct) reset() {
a.last = sqltypes.NULL
}

// aggregatorDistinctHash is a hash-based distinct tracker that uses hash sets
// to track seen values. Unlike aggregatorDistinct which requires sorted data,
// this can handle arbitrary input order. It's used when multiple distinct
// aggregations with different expressions exist in a single query.
type aggregatorDistinctHash struct {
	// column is the index in the row of the value that gets hashed.
	column int
	// wsColumn is the index of the weight string column that is hashed
	// instead when hashing the value itself fails; a negative index means
	// no fallback is available.
	wsColumn int
	// seen holds the 128-bit hash of every non-NULL value recorded since
	// the last reset.
	seen map[vthash.Hash]struct{}
	// hasher is reused for every row; it is reset before each hash.
	hasher vthash.Hasher
	// coll is the collation passed to the hash function for the value.
	coll collations.ID
	// collationEnv is populated by createDistinctTracker.
	// NOTE(review): none of this type's visible methods read it - confirm
	// whether it is still needed.
	collationEnv *collations.Environment
	// typ is the type passed to the hash function for the value.
	typ querypb.Type
	// sqlmode is forwarded unchanged to the hash function.
	sqlmode evalengine.SQLMode
	// values is forwarded to the hash function when hashing the value; the
	// weight string fallback passes nil instead.
	values *evalengine.EnumSetValues
}

// shouldReturn reports whether row must be skipped by the calling aggregator
// because the value in a.column has already been recorded since the last
// reset.
//
// NULL values are never recorded and never reported as duplicates. The value
// is hashed with the tracker's collation and type; if that fails and a weight
// string column is available (a.wsColumn >= 0), the weight string is hashed
// instead as binary data.
//
// When no hash can be computed the result is (true, err). The aggregators
// only propagate the error when the boolean is true, so returning false here
// would silently discard the error and let the row be accumulated.
func (a *aggregatorDistinctHash) shouldReturn(row []sqltypes.Value) (bool, error) {
	val := row[a.column]
	if val.IsNull() {
		// NULL values are never counted as distinct duplicates
		return false, nil
	}

	a.hasher.Reset()
	err := evalengine.NullsafeHashcode128(&a.hasher, val, a.coll, a.typ, a.sqlmode, a.values)
	if err != nil && a.wsColumn >= 0 {
		// Fall back to the weight string. It is hashed as VarBinary, so it
		// needs the binary collation rather than an unknown one.
		a.hasher.Reset()
		err = evalengine.NullsafeHashcode128(&a.hasher, row[a.wsColumn], collations.CollationBinaryID, sqltypes.VarBinary, a.sqlmode, nil)
	}
	if err != nil {
		// Signal "return" so that the caller surfaces the error.
		return true, err
	}

	hash := a.hasher.Sum128()
	if _, found := a.seen[hash]; found {
		return true, nil // Already seen, skip
	}
	if a.seen == nil {
		// Keep a zero-value tracker usable: writing to a nil map panics.
		a.seen = make(map[vthash.Hash]struct{})
	}
	a.seen[hash] = struct{}{}
	return false, nil
}

// reset forgets every value recorded so far by replacing the set of seen
// hashes with a fresh, empty one.
func (a *aggregatorDistinctHash) reset() {
	a.seen = map[vthash.Hash]struct{}{}
}

type aggregatorCount struct {
from int
n int64
distinct aggregatorDistinct
distinct distinctTracker
}

func (a *aggregatorCount) add(row []sqltypes.Value) error {
if row[a.from].IsNull() {
return nil
}
if ret, err := a.distinct.shouldReturn(row); ret {
return err
if a.distinct != nil {
if ret, err := a.distinct.shouldReturn(row); ret {
return err
}
}
a.n++
return nil
Expand All @@ -183,7 +248,9 @@ func (a *aggregatorCount) finish(*evalengine.ExpressionEnv, collations.ID) (sqlt

func (a *aggregatorCount) reset() {
a.n = 0
a.distinct.reset()
if a.distinct != nil {
a.distinct.reset()
}
}

type aggregatorCountStar struct {
Expand Down Expand Up @@ -235,15 +302,17 @@ func (a *aggregatorMinMax) reset() {
type aggregatorSum struct {
from int
sum evalengine.Sum
distinct aggregatorDistinct
distinct distinctTracker
}

func (a *aggregatorSum) add(row []sqltypes.Value) error {
if row[a.from].IsNull() {
return nil
}
if ret, err := a.distinct.shouldReturn(row); ret {
return err
if a.distinct != nil {
if ret, err := a.distinct.shouldReturn(row); ret {
return err
}
}
return a.sum.Add(row[a.from])
}
Expand All @@ -254,7 +323,9 @@ func (a *aggregatorSum) finish(*evalengine.ExpressionEnv, collations.ID) (sqltyp

func (a *aggregatorSum) reset() {
a.sum.Reset()
a.distinct.reset()
if a.distinct != nil {
a.distinct.reset()
}
}

type aggregatorScalar struct {
Expand Down Expand Up @@ -411,6 +482,35 @@ func isComparable(typ sqltypes.Type) bool {
return false
}

// createDistinctTracker builds the distinct tracker for one aggregation.
//
// A negative distinctCol means the aggregation is not distinct, and a nil
// tracker is returned. When aggr.UseHashDistinct is set the result is a
// hash-based tracker, which does not depend on input order; otherwise it is
// the sort-based tracker, which expects sorted input. wsCol and sourceType
// are only used by the hash-based tracker.
func createDistinctTracker(aggr *AggregateParams, distinctCol int, wsCol int, sourceType querypb.Type) distinctTracker {
	switch {
	case distinctCol < 0:
		// Return an untyped nil so callers' nil checks on the interface hold.
		return nil

	case aggr.UseHashDistinct:
		hashTracker := &aggregatorDistinctHash{
			column:       distinctCol,
			wsColumn:     wsCol,
			seen:         make(map[vthash.Hash]struct{}),
			hasher:       vthash.New(),
			coll:         aggr.Type.Collation(),
			collationEnv: aggr.CollationEnv,
			typ:          sourceType,
			values:       aggr.Type.Values(),
		}
		return hashTracker

	default:
		sortTracker := &aggregatorDistinct{
			column:       distinctCol,
			coll:         aggr.Type.Collation(),
			collationEnv: aggr.CollationEnv,
			values:       aggr.Type.Values(),
		}
		return sortTracker
	}
}

func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env *evalengine.ExpressionEnv, collation collations.ID) (*aggregationState, []*querypb.Field, error) {
fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { return from.CloneVT() })

Expand All @@ -423,12 +523,14 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env
targetType := aggr.typ(sourceType, env, collation)

var ag aggregator
var distinct = -1
var distinctCol = -1
var wsCol = -1

if aggr.Opcode.IsDistinct() {
distinct = aggr.KeyCol
distinctCol = aggr.KeyCol
wsCol = aggr.WCol
if aggr.WAssigned() && !isComparable(sourceType) {
distinct = aggr.WCol
distinctCol = aggr.WCol
}
}

Expand All @@ -444,13 +546,8 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env

case opcode.AggregateCount, opcode.AggregateCountDistinct:
ag = &aggregatorCount{
from: aggr.Col,
distinct: aggregatorDistinct{
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
from: aggr.Col,
distinct: createDistinctTracker(aggr, distinctCol, wsCol, sourceType),
}

case opcode.AggregateSum, opcode.AggregateSumDistinct:
Expand All @@ -463,14 +560,9 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env
}

ag = &aggregatorSum{
from: aggr.Col,
sum: sum,
distinct: aggregatorDistinct{
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
from: aggr.Col,
sum: sum,
distinct: createDistinctTracker(aggr, distinctCol, wsCol, sourceType),
}

case opcode.AggregateMin:
Expand Down
128 changes: 128 additions & 0 deletions go/vt/vtgate/engine/aggregations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,131 @@ func BenchmarkScalarAggregate(b *testing.B) {
})
}
}

// TestHashBasedDistinct checks that a COUNT(DISTINCT ...) aggregation with
// UseHashDistinct set counts every value exactly once even when duplicates
// are not adjacent in the input.
func TestHashBasedDistinct(t *testing.T) {
	// Unsorted input with repeats: the distinct values are 1, 2, 3 and 4.
	input := []int64{1, 3, 2, 1, 3, 4}
	rows := make([][]sqltypes.Value, 0, len(input))
	for _, v := range input {
		rows = append(rows, []sqltypes.Value{sqltypes.NewInt64(v)})
	}

	source := &fakePrimitive{
		allResultsInOneCall: true,
		results: []*sqltypes.Result{{
			Fields: sqltypes.MakeTestFields("col", "int64"),
			Rows:   rows,
		}},
	}

	// A single hash-based distinct count over column 0, with no weight
	// string column assigned.
	countDistinct := &AggregateParams{
		Opcode:          AggregateCountDistinct,
		Col:             0,
		KeyCol:          0,
		WCol:            -1,
		UseHashDistinct: true,
	}
	agg := &ScalarAggregate{
		Aggregates: []*AggregateParams{countDistinct},
		Input:      source,
	}

	got, err := agg.TryExecute(context.Background(), &noopVCursor{}, nil, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n := len(got.Rows); n != 1 {
		t.Fatalf("expected 1 row, got %d", n)
	}

	count, err := got.Rows[0][0].ToInt64()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if count != 4 {
		t.Errorf("expected distinct count of 4, got %d", count)
	}
}

// TestMultipleHashBasedDistinct checks that two hash-based distinct counts in
// the same aggregation track their own columns independently.
func TestMultipleHashBasedDistinct(t *testing.T) {
	// col1 holds 1, 1, 2, 2, 3      -> 3 distinct values
	// col2 holds 10, 20, 10, 30, 40 -> 4 distinct values
	pairs := [][2]int64{{1, 10}, {1, 20}, {2, 10}, {2, 30}, {3, 40}}
	rows := make([][]sqltypes.Value, 0, len(pairs))
	for _, p := range pairs {
		rows = append(rows, []sqltypes.Value{sqltypes.NewInt64(p[0]), sqltypes.NewInt64(p[1])})
	}

	source := &fakePrimitive{
		allResultsInOneCall: true,
		results: []*sqltypes.Result{{
			Fields: sqltypes.MakeTestFields("col1|col2", "int64|int64"),
			Rows:   rows,
		}},
	}

	// One hash-based distinct count per input column.
	hashCountOn := func(col int) *AggregateParams {
		return &AggregateParams{
			Opcode:          AggregateCountDistinct,
			Col:             col,
			KeyCol:          col,
			WCol:            -1,
			UseHashDistinct: true,
		}
	}
	agg := &ScalarAggregate{
		Aggregates: []*AggregateParams{hashCountOn(0), hashCountOn(1)},
		Input:      source,
	}

	got, err := agg.TryExecute(context.Background(), &noopVCursor{}, nil, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n := len(got.Rows); n != 1 {
		t.Fatalf("expected 1 row, got %d", n)
	}

	// countAt converts one output column of the single result row.
	countAt := func(col int) int64 {
		n, err := got.Rows[0][col].ToInt64()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		return n
	}

	if first := countAt(0); first != 3 {
		t.Errorf("expected first distinct count of 3, got %d", first)
	}
	if second := countAt(1); second != 4 {
		t.Errorf("expected second distinct count of 4, got %d", second)
	}
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega
aggrParam.OrigOpcode = aggr.OriginalOpCode
aggrParam.WCol = aggr.WSOffset
aggrParam.Type = aggr.GetTypeCollation(ctx)
// Pass hash-based distinct flag from planner to engine
aggrParam.UseHashDistinct = aggr.UseHashDistinct || op.UseHashDistinct
aggregates = append(aggregates, aggrParam)
}

Expand Down
Loading
Loading