Skip to content

Commit a3fb83a

Browse files
authored
fix: distinct_eliminated is rewritten as distinct_on_group_key (#19142)
* fix: distinct_eliminated is rewritten as distinct_on_group_key
* chore: codefmt
* chore: codefmt
* fix: group by set
* fix: explain for agg
* chore: fix explain_native/aggregate.test
* chore: fix explain agg
* fix: stop collapsing distinct on grouping sets to 0/1
1 parent 937a06e commit a3fb83a

File tree

4 files changed

+133
-60
lines changed

4 files changed

+133
-60
lines changed

src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/normalize_aggregate.rs

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ use crate::optimizer::Optimizer;
2222
use crate::optimizer::ir::SExpr;
2323
use crate::plans::Aggregate;
2424
use crate::plans::BoundColumnRef;
25+
use crate::plans::ConstantExpr;
2526
use crate::plans::EvalScalar;
27+
use crate::plans::FunctionCall;
2628
use crate::plans::RelOperator;
2729
use crate::plans::ScalarExpr;
2830
use crate::plans::ScalarItem;
@@ -54,6 +56,7 @@ impl RuleNormalizeAggregateOptimizer {
5456
let mut work_expr = None;
5557
let mut alias_functions_index = vec![];
5658
let mut new_aggregate_functions = Vec::with_capacity(aggregate.aggregate_functions.len());
59+
let mut post_aggregate_scalars = Vec::new();
5760

5861
let mut rewritten = false;
5962

@@ -80,10 +83,12 @@ impl RuleNormalizeAggregateOptimizer {
8083
continue;
8184
}
8285

83-
// rewrite count(distinct items) to count() if items in group by
84-
let distinct_eliminated = ((function.distinct && function.func_name == "count")
86+
// rewrite count(distinct item)/uniq/count_distinct on grouping key to 1 (or 0 if null)
87+
let distinct_on_group_key = ((function.distinct && function.func_name == "count")
8588
|| function.func_name == "uniq"
8689
|| function.func_name == "count_distinct")
90+
&& function.args.len() == 1
91+
&& aggregate.grouping_sets.is_none()
8792
&& function.args.iter().all(|expr| {
8893
if let ScalarExpr::BoundColumnRef(r) = expr {
8994
aggregate
@@ -95,16 +100,57 @@ impl RuleNormalizeAggregateOptimizer {
95100
}
96101
});
97102

98-
if distinct_eliminated {
103+
if distinct_on_group_key {
99104
rewritten = true;
100-
let mut new_function = function.clone();
101-
new_function.args = vec![];
102-
new_function.func_name = "count".to_string();
103105

104-
new_aggregate_functions.push(ScalarItem {
106+
// Grouping sets rewrite will wrap grouping keys into nullable and inject NULLs
107+
// for sets where the key is absent, so treat them as nullable even if the
108+
// original column type is non-nullable.
109+
let mut nullable = function.args[0].data_type()?.is_nullable_or_null();
110+
if !nullable {
111+
if let Some(grouping_sets) = &aggregate.grouping_sets {
112+
if !grouping_sets.sets.is_empty() {
113+
nullable = true;
114+
}
115+
}
116+
}
117+
118+
let scalar = if nullable {
119+
let not_null_check = ScalarExpr::FunctionCall(FunctionCall {
120+
span: None,
121+
func_name: "is_not_null".to_string(),
122+
params: vec![],
123+
arguments: vec![function.args[0].clone()],
124+
});
125+
126+
ScalarExpr::FunctionCall(FunctionCall {
127+
span: None,
128+
func_name: "if".to_string(),
129+
params: vec![],
130+
arguments: vec![
131+
not_null_check,
132+
ScalarExpr::ConstantExpr(ConstantExpr {
133+
span: None,
134+
value: 1u64.into(),
135+
}),
136+
ScalarExpr::ConstantExpr(ConstantExpr {
137+
span: None,
138+
value: 0u64.into(),
139+
}),
140+
],
141+
})
142+
} else {
143+
ScalarExpr::ConstantExpr(ConstantExpr {
144+
span: None,
145+
value: 1u64.into(),
146+
})
147+
};
148+
149+
post_aggregate_scalars.push(ScalarItem {
105150
index: aggregate_function.index,
106-
scalar: ScalarExpr::AggregateFunction(new_function),
151+
scalar,
107152
});
153+
108154
continue;
109155
}
110156
}
@@ -130,12 +176,10 @@ impl RuleNormalizeAggregateOptimizer {
130176
Arc::new(s_expr.child(0)?.clone()),
131177
);
132178

179+
let mut scalar_items = Vec::new();
180+
133181
if let Some((work_index, work_c)) = work_expr {
134-
if alias_functions_index.len() < 2 {
135-
return Ok(new_aggregate);
136-
}
137-
if !alias_functions_index.is_empty() {
138-
let mut scalar_items = Vec::with_capacity(alias_functions_index.len());
182+
if alias_functions_index.len() >= 2 {
139183
for (alias_function_index, _alias_function) in alias_functions_index {
140184
scalar_items.push(ScalarItem {
141185
index: alias_function_index,
@@ -156,21 +200,24 @@ impl RuleNormalizeAggregateOptimizer {
156200
}),
157201
})
158202
}
159-
160-
new_aggregate = SExpr::create_unary(
161-
Arc::new(
162-
EvalScalar {
163-
items: scalar_items,
164-
}
165-
.into(),
166-
),
167-
Arc::new(new_aggregate),
168-
);
169203
}
170-
Ok(new_aggregate)
171-
} else {
172-
Ok(new_aggregate)
173204
}
205+
206+
scalar_items.extend(post_aggregate_scalars);
207+
208+
if !scalar_items.is_empty() {
209+
new_aggregate = SExpr::create_unary(
210+
Arc::new(
211+
EvalScalar {
212+
items: scalar_items,
213+
}
214+
.into(),
215+
),
216+
Arc::new(new_aggregate),
217+
);
218+
}
219+
220+
Ok(new_aggregate)
174221
}
175222
}
176223

tests/sqllogictests/suites/base/03_common/03_0001_select_aggregator.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,21 @@ select sum(a), sum(b), sum(c), sum(e) from ( select (number % 8)::UInt64 as a,(n
3030
----
3131
342 396 450 100
3232

33+
query III
34+
select k, count(distinct k), count(*) from (select number % 3 as k from numbers(5)) group by k order by k
35+
----
36+
0 1 2
37+
1 1 2
38+
2 1 1
39+
40+
query II
41+
select n, count(distinct n), count(*) from (select cast(NULL as Nullable(Int32)) as n from numbers(3)) group by n
42+
----
43+
NULL 0 3
44+
45+
query II rowsort
46+
select k, count(distinct k) from (select number % 2 as k from numbers(4)) group by grouping sets((k),()) order by k
47+
----
48+
0 1
49+
1 1
50+
NULL 2

tests/sqllogictests/suites/mode/standalone/explain/aggregate.test

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -457,28 +457,32 @@ query T
457457
EXPLAIN SELECT referer, avg(isrefresh), count(distinct referer) FROM t GROUP BY referer;
458458
----
459459
EvalScalar
460-
├── output columns: [count(DISTINCT referer) (#4), t.referer (#0), sum(isrefresh) / if(count(isrefresh) = 0, 1, count(isrefresh)) (#5)]
460+
├── output columns: [t.referer (#0), count(DISTINCT referer) (#4), sum(isrefresh) / if(count(isrefresh) = 0, 1, count(isrefresh)) (#5)]
461461
├── expressions: [sum(isrefresh) (#2) / CAST(if(CAST(count(isrefresh) (#3) = 0 AS Boolean NULL), 1, count(isrefresh) (#3)) AS UInt64 NULL)]
462462
├── estimated rows: 1.00
463-
└── AggregateFinal
464-
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), count(DISTINCT referer) (#4), t.referer (#0)]
465-
├── group by: [referer]
466-
├── aggregate functions: [sum(isrefresh), count(), count()]
463+
└── EvalScalar
464+
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), t.referer (#0), count(DISTINCT referer) (#4)]
465+
├── expressions: [1]
467466
├── estimated rows: 1.00
468-
└── AggregatePartial
467+
└── AggregateFinal
468+
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), t.referer (#0)]
469469
├── group by: [referer]
470-
├── aggregate functions: [sum(isrefresh), count(), count()]
470+
├── aggregate functions: [sum(isrefresh), count()]
471471
├── estimated rows: 1.00
472-
└── TableScan
473-
├── table: default.default.t
474-
├── scan id: 0
475-
├── output columns: [referer (#0), isrefresh (#1)]
476-
├── read rows: 0
477-
├── read size: 0
478-
├── partitions total: 0
479-
├── partitions scanned: 0
480-
├── push downs: [filters: [], limit: NONE]
481-
└── estimated rows: 0.00
472+
└── AggregatePartial
473+
├── group by: [referer]
474+
├── aggregate functions: [sum(isrefresh), count()]
475+
├── estimated rows: 1.00
476+
└── TableScan
477+
├── table: default.default.t
478+
├── scan id: 0
479+
├── output columns: [referer (#0), isrefresh (#1)]
480+
├── read rows: 0
481+
├── read size: 0
482+
├── partitions total: 0
483+
├── partitions scanned: 0
484+
├── push downs: [filters: [], limit: NONE]
485+
└── estimated rows: 0.00
482486

483487
query T
484488
EXPLAIN SELECT referer, isrefresh, count() FROM t GROUP BY referer, isrefresh order by referer, isrefresh desc limit 10;

tests/sqllogictests/suites/mode/standalone/explain_native/aggregate.test

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -292,28 +292,32 @@ query T
292292
EXPLAIN SELECT referer, avg(isrefresh), count(distinct referer) FROM t GROUP BY referer;
293293
----
294294
EvalScalar
295-
├── output columns: [count(DISTINCT referer) (#4), t.referer (#0), sum(isrefresh) / if(count(isrefresh) = 0, 1, count(isrefresh)) (#5)]
295+
├── output columns: [t.referer (#0), count(DISTINCT referer) (#4), sum(isrefresh) / if(count(isrefresh) = 0, 1, count(isrefresh)) (#5)]
296296
├── expressions: [sum(isrefresh) (#2) / CAST(if(CAST(count(isrefresh) (#3) = 0 AS Boolean NULL), 1, count(isrefresh) (#3)) AS UInt64 NULL)]
297297
├── estimated rows: 1.00
298-
└── AggregateFinal
299-
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), count(DISTINCT referer) (#4), t.referer (#0)]
300-
├── group by: [referer]
301-
├── aggregate functions: [sum(isrefresh), count(), count()]
298+
└── EvalScalar
299+
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), t.referer (#0), count(DISTINCT referer) (#4)]
300+
├── expressions: [1]
302301
├── estimated rows: 1.00
303-
└── AggregatePartial
302+
└── AggregateFinal
303+
├── output columns: [sum(isrefresh) (#2), count(isrefresh) (#3), t.referer (#0)]
304304
├── group by: [referer]
305-
├── aggregate functions: [sum(isrefresh), count(), count()]
305+
├── aggregate functions: [sum(isrefresh), count()]
306306
├── estimated rows: 1.00
307-
└── TableScan
308-
├── table: default.default.t
309-
├── scan id: 0
310-
├── output columns: [referer (#0), isrefresh (#1)]
311-
├── read rows: 0
312-
├── read size: 0
313-
├── partitions total: 0
314-
├── partitions scanned: 0
315-
├── push downs: [filters: [], limit: NONE]
316-
└── estimated rows: 0.00
307+
└── AggregatePartial
308+
├── group by: [referer]
309+
├── aggregate functions: [sum(isrefresh), count()]
310+
├── estimated rows: 1.00
311+
└── TableScan
312+
├── table: default.default.t
313+
├── scan id: 0
314+
├── output columns: [referer (#0), isrefresh (#1)]
315+
├── read rows: 0
316+
├── read size: 0
317+
├── partitions total: 0
318+
├── partitions scanned: 0
319+
├── push downs: [filters: [], limit: NONE]
320+
└── estimated rows: 0.00
317321

318322
statement ok
319323
DROP TABLE IF EXISTS t;

0 commit comments

Comments
 (0)