Skip to content

Commit a9d4d52

Browse files
authored
feat(substrait): AggregateRel grouping_expression support (#13173)
1 parent b40a298 commit a9d4d52

File tree

4 files changed

+210
-36
lines changed

4 files changed

+210
-36
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion::logical_expr::{
3333
expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
3434
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
3535
};
36+
use substrait::proto::aggregate_rel::Grouping;
3637
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
3738
use substrait::proto::expression_reference::ExprType;
3839
use url::Url;
@@ -665,39 +666,48 @@ pub async fn from_substrait_rel(
665666
let input = LogicalPlanBuilder::from(
666667
from_substrait_rel(ctx, input, extensions).await?,
667668
);
668-
let mut group_expr = vec![];
669-
let mut aggr_expr = vec![];
669+
let mut ref_group_exprs = vec![];
670+
671+
for e in &agg.grouping_expressions {
672+
let x =
673+
from_substrait_rex(ctx, e, input.schema(), extensions).await?;
674+
ref_group_exprs.push(x);
675+
}
676+
677+
let mut group_exprs = vec![];
678+
let mut aggr_exprs = vec![];
670679

671680
match agg.groupings.len() {
672681
1 => {
673-
for e in &agg.groupings[0].grouping_expressions {
674-
let x =
675-
from_substrait_rex(ctx, e, input.schema(), extensions)
676-
.await?;
677-
group_expr.push(x);
678-
}
682+
group_exprs.extend_from_slice(
683+
&from_substrait_grouping(
684+
ctx,
685+
&agg.groupings[0],
686+
&ref_group_exprs,
687+
input.schema(),
688+
extensions,
689+
)
690+
.await?,
691+
);
679692
}
680693
_ => {
681694
let mut grouping_sets = vec![];
682695
for grouping in &agg.groupings {
683-
let mut grouping_set = vec![];
684-
for e in &grouping.grouping_expressions {
685-
let x = from_substrait_rex(
686-
ctx,
687-
e,
688-
input.schema(),
689-
extensions,
690-
)
691-
.await?;
692-
grouping_set.push(x);
693-
}
696+
let grouping_set = from_substrait_grouping(
697+
ctx,
698+
grouping,
699+
&ref_group_exprs,
700+
input.schema(),
701+
extensions,
702+
)
703+
.await?;
694704
grouping_sets.push(grouping_set);
695705
}
696706
// Single-element grouping expression of type Expr::GroupingSet.
697707
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
698708
// parsed by the producer and consumer, since Substrait does not have a type dedicated
699709
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
700-
group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
710+
group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
701711
grouping_sets,
702712
)));
703713
}
@@ -755,9 +765,9 @@ pub async fn from_substrait_rel(
755765
"Aggregate without aggregate function is not supported"
756766
),
757767
};
758-
aggr_expr.push(agg_func?.as_ref().clone());
768+
aggr_exprs.push(agg_func?.as_ref().clone());
759769
}
760-
input.aggregate(group_expr, aggr_expr)?.build()
770+
input.aggregate(group_exprs, aggr_exprs)?.build()
761771
} else {
762772
not_impl_err!("Aggregate without an input is not valid")
763773
}
@@ -2762,6 +2772,29 @@ fn from_substrait_null(
27622772
}
27632773
}
27642774

2775+
#[allow(deprecated)]
2776+
async fn from_substrait_grouping(
2777+
ctx: &SessionContext,
2778+
grouping: &Grouping,
2779+
expressions: &[Expr],
2780+
input_schema: &DFSchemaRef,
2781+
extensions: &Extensions,
2782+
) -> Result<Vec<Expr>> {
2783+
let mut group_exprs = vec![];
2784+
if !grouping.grouping_expressions.is_empty() {
2785+
for e in &grouping.grouping_expressions {
2786+
let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?;
2787+
group_exprs.push(expr);
2788+
}
2789+
return Ok(group_exprs);
2790+
}
2791+
for idx in &grouping.expression_references {
2792+
let e = &expressions[*idx as usize];
2793+
group_exprs.push(e.clone());
2794+
}
2795+
Ok(group_exprs)
2796+
}
2797+
27652798
fn from_substrait_field_reference(
27662799
field_ref: &FieldReference,
27672800
input_schema: &DFSchema,

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ pub fn to_substrait_rel(
361361
}
362362
LogicalPlan::Aggregate(agg) => {
363363
let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
364-
let groupings = to_substrait_groupings(
364+
let (grouping_expressions, groupings) = to_substrait_groupings(
365365
ctx,
366366
&agg.group_expr,
367367
agg.input.schema(),
@@ -377,7 +377,7 @@ pub fn to_substrait_rel(
377377
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
378378
common: None,
379379
input: Some(input),
380-
grouping_expressions: vec![],
380+
grouping_expressions,
381381
groupings,
382382
measures,
383383
advanced_extension: None,
@@ -774,14 +774,20 @@ pub fn parse_flat_grouping_exprs(
774774
exprs: &[Expr],
775775
schema: &DFSchemaRef,
776776
extensions: &mut Extensions,
777+
ref_group_exprs: &mut Vec<Expression>,
777778
) -> Result<Grouping> {
778-
let grouping_expressions = exprs
779-
.iter()
780-
.map(|e| to_substrait_rex(ctx, e, schema, 0, extensions))
781-
.collect::<Result<Vec<_>>>()?;
779+
let mut expression_references = vec![];
780+
let mut grouping_expressions = vec![];
781+
782+
for e in exprs {
783+
let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
784+
grouping_expressions.push(rex.clone());
785+
ref_group_exprs.push(rex);
786+
expression_references.push((ref_group_exprs.len() - 1) as u32);
787+
}
782788
Ok(Grouping {
783789
grouping_expressions,
784-
expression_references: vec![],
790+
expression_references,
785791
})
786792
}
787793

@@ -790,16 +796,25 @@ pub fn to_substrait_groupings(
790796
exprs: &[Expr],
791797
schema: &DFSchemaRef,
792798
extensions: &mut Extensions,
793-
) -> Result<Vec<Grouping>> {
794-
match exprs.len() {
799+
) -> Result<(Vec<Expression>, Vec<Grouping>)> {
800+
let mut ref_group_exprs = vec![];
801+
let groupings = match exprs.len() {
795802
1 => match &exprs[0] {
796803
Expr::GroupingSet(gs) => match gs {
797804
GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
798805
"GroupingSet CUBE is not yet supported".to_string(),
799806
)),
800807
GroupingSet::GroupingSets(sets) => Ok(sets
801808
.iter()
802-
.map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions))
809+
.map(|set| {
810+
parse_flat_grouping_exprs(
811+
ctx,
812+
set,
813+
schema,
814+
extensions,
815+
&mut ref_group_exprs,
816+
)
817+
})
803818
.collect::<Result<Vec<_>>>()?),
804819
GroupingSet::Rollup(set) => {
805820
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
@@ -810,19 +825,34 @@ pub fn to_substrait_groupings(
810825
.iter()
811826
.rev()
812827
.map(|set| {
813-
parse_flat_grouping_exprs(ctx, set, schema, extensions)
828+
parse_flat_grouping_exprs(
829+
ctx,
830+
set,
831+
schema,
832+
extensions,
833+
&mut ref_group_exprs,
834+
)
814835
})
815836
.collect::<Result<Vec<_>>>()?)
816837
}
817838
},
818839
_ => Ok(vec![parse_flat_grouping_exprs(
819-
ctx, exprs, schema, extensions,
840+
ctx,
841+
exprs,
842+
schema,
843+
extensions,
844+
&mut ref_group_exprs,
820845
)?]),
821846
},
822847
_ => Ok(vec![parse_flat_grouping_exprs(
823-
ctx, exprs, schema, extensions,
848+
ctx,
849+
exprs,
850+
schema,
851+
extensions,
852+
&mut ref_group_exprs,
824853
)?]),
825-
}
854+
}?;
855+
Ok((ref_group_exprs, groupings))
826856
}
827857

828858
#[allow(deprecated)]

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
665665
.await
666666
}
667667

668+
#[tokio::test]
669+
async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
670+
let proto_plan =
671+
read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");
672+
673+
assert_expected_plan_substrait(
674+
proto_plan,
675+
"Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
676+
\n TableScan: data projection=[a]",
677+
)
678+
.await
679+
}
680+
668681
#[tokio::test]
669682
async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
670683
let proto_plan =
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
5+
}
6+
],
7+
"extensions": [
8+
{
9+
"extensionFunction": {
10+
"functionAnchor": 185,
11+
"name": "count:any"
12+
}
13+
}
14+
],
15+
"relations": [
16+
{
17+
"root": {
18+
"input": {
19+
"aggregate": {
20+
"input": {
21+
"read": {
22+
"common": {
23+
"direct": {}
24+
},
25+
"baseSchema": {
26+
"names": [
27+
"a"
28+
],
29+
"struct": {
30+
"types": [
31+
{
32+
"i64": {
33+
"nullability": "NULLABILITY_NULLABLE"
34+
}
35+
}
36+
],
37+
"nullability": "NULLABILITY_NULLABLE"
38+
}
39+
},
40+
"namedTable": {
41+
"names": [
42+
"data"
43+
]
44+
}
45+
}
46+
},
47+
"grouping_expressions": [
48+
{
49+
"selection": {
50+
"directReference": {
51+
"structField": {}
52+
},
53+
"rootReference": {}
54+
}
55+
}
56+
],
57+
"groupings": [
58+
{
59+
"expression_references": [0]
60+
}
61+
],
62+
"measures": [
63+
{
64+
"measure": {
65+
"functionReference": 185,
66+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
67+
"outputType": {
68+
"i64": {}
69+
},
70+
"arguments": [
71+
{
72+
"value": {
73+
"selection": {
74+
"directReference": {
75+
"structField": {}
76+
},
77+
"rootReference": {}
78+
}
79+
}
80+
}
81+
]
82+
}
83+
}
84+
]
85+
}
86+
},
87+
"names": [
88+
"a",
89+
"countA"
90+
]
91+
}
92+
}
93+
],
94+
"version": {
95+
"minorNumber": 54,
96+
"producer": "subframe"
97+
}
98+
}

0 commit comments

Comments
 (0)