Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 115 additions & 210 deletions src/query/service/src/physical_plans/physical_aggregate_final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use databend_common_sql::optimizer::ir::SExpr;
use databend_common_sql::plans::Aggregate;
use databend_common_sql::plans::AggregateMode;
use databend_common_sql::plans::ConstantTableScan;
use databend_common_sql::plans::ScalarItem;
use itertools::Itertools;

use super::AggregateExpand;
Expand Down Expand Up @@ -259,111 +260,8 @@ impl PhysicalPlanBuilder {
.map(|item| Ok(item.scalar.as_expr()?.sql_display()))
.collect::<Result<Vec<_>>>()?;

let mut agg_funcs: Vec<AggregateFunctionDesc> = agg
.aggregate_functions
.iter()
.map(|v| match &v.scalar {
ScalarExpr::AggregateFunction(agg) => {
let arg_indices = agg
.args
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;
let sort_desc_indices = agg.sort_descs
.iter()
.map(|desc| {
if let ScalarExpr::BoundColumnRef(col) = &desc.expr {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function description must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<_>>()?;
let sort_descs = agg.sort_descs
.iter()
.map(|desc| desc.try_into())
.collect::<Result<_>>()?;
Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: agg.func_name.clone(),
udaf: None,
return_type: *agg.return_type.clone(),
args,
params: agg.params.clone(),
sort_descs,
},
output_column: v.index,
arg_indices,
sort_desc_indices,
display: v.scalar.as_expr()?.sql_display(),
})
}
ScalarExpr::UDAFCall(udaf) => {
let arg_indices = udaf
.arguments
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;

Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: udaf.name.clone(),
udaf: Some((udaf.udf_type.clone(), udaf.state_fields.clone())),
return_type: *udaf.return_type.clone(),
args,
params: vec![],
sort_descs: vec![],
},
output_column: v.index,
arg_indices,
sort_desc_indices: vec![],
display: v.scalar.as_expr()?.sql_display(),
})
}
_ => Err(ErrorCode::Internal(
"Expected aggregate function".to_string(),
)),
})
.collect::<Result<_>>()?;
let mut agg_funcs =
build_aggregate_function(&agg.aggregate_functions, &input_schema)?;

let settings = self.ctx.get_settings();
let mut group_by_shuffle_mode = settings.get_group_by_shuffle_mode()?;
Expand Down Expand Up @@ -507,111 +405,8 @@ impl PhysicalPlanBuilder {
aggregate.input.output_schema()?
};

let mut agg_funcs: Vec<AggregateFunctionDesc> = agg
.aggregate_functions
.iter()
.map(|v| match &v.scalar {
ScalarExpr::AggregateFunction(agg) => {
let arg_indices = agg
.args
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let sort_desc_indices = agg.sort_descs
.iter()
.map(|desc| {
if let ScalarExpr::BoundColumnRef(col) = &desc.expr {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function sort description must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<_>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;
let sort_descs = agg.sort_descs
.iter()
.map(|desc| desc.try_into())
.collect::<Result<_>>()?;
Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: agg.func_name.clone(),
udaf: None,
return_type: *agg.return_type.clone(),
args,
params: agg.params.clone(),
sort_descs,
},
output_column: v.index,
arg_indices,
sort_desc_indices,
display: v.scalar.as_expr()?.sql_display(),
})
}
ScalarExpr::UDAFCall(udaf) => {
let arg_indices = udaf
.arguments
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;

Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: udaf.name.clone(),
udaf: Some((udaf.udf_type.clone(), udaf.state_fields.clone())),
return_type: *udaf.return_type.clone(),
args,
params: vec![],
sort_descs: vec![],
},
output_column: v.index,
arg_indices,
sort_desc_indices: vec![],
display: v.scalar.as_expr()?.sql_display(),
})
}
_ => Err(ErrorCode::Internal(
"Expected aggregate function".to_string(),
)),
})
.collect::<Result<_>>()?;
let mut agg_funcs =
build_aggregate_function(&agg.aggregate_functions, &input_schema)?;

if let Some(grouping_sets) = agg.grouping_sets.as_ref() {
// The argument types are wrapped nullable due to `AggregateExpand` plan. We should recover them to original types.
Expand Down Expand Up @@ -676,3 +471,113 @@ impl PhysicalPlanBuilder {
Ok(result)
}
}

fn build_aggregate_function(
agg_functions: &[ScalarItem],
input_schema: &DataSchemaRef,
) -> Result<Vec<AggregateFunctionDesc>> {
agg_functions
.iter()
.map(|v| match &v.scalar {
ScalarExpr::AggregateFunction(agg) => {
let arg_indices = agg
.args
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef".to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let sort_desc_indices = agg
.sort_descs
.iter()
.map(|desc| {
if let ScalarExpr::BoundColumnRef(col) = &desc.expr {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function sort description must be a BoundColumnRef"
.to_string(),
))
}
})
.collect::<Result<_>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;
let sort_descs = agg
.sort_descs
.iter()
.map(|desc| desc.try_into())
.collect::<Result<_>>()?;
Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: agg.func_name.clone(),
udaf: None,
return_type: *agg.return_type.clone(),
args,
params: agg.params.clone(),
sort_descs,
},
output_column: v.index,
arg_indices,
sort_desc_indices,
display: v.scalar.as_expr()?.sql_display(),
})
}
ScalarExpr::UDAFCall(udaf) => {
let arg_indices = udaf
.arguments
.iter()
.map(|arg| {
if let ScalarExpr::BoundColumnRef(col) = arg {
Ok(col.column.index)
} else {
Err(ErrorCode::Internal(
"Aggregate function argument must be a BoundColumnRef".to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let args = arg_indices
.iter()
.map(|i| {
Ok(input_schema
.field_with_name(&i.to_string())?
.data_type()
.clone())
})
.collect::<Result<_>>()?;

Ok(AggregateFunctionDesc {
sig: AggregateFunctionSignature {
name: udaf.name.clone(),
udaf: Some((udaf.udf_type.clone(), udaf.state_fields.clone())),
return_type: *udaf.return_type.clone(),
args,
params: vec![],
sort_descs: vec![],
},
output_column: v.index,
arg_indices,
sort_desc_indices: vec![],
display: v.scalar.as_expr()?.sql_display(),
})
}
_ => Err(ErrorCode::Internal(
"Expected aggregate function".to_string(),
)),
})
.collect::<Result<_>>()
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ mod new_transform_final_aggregate;
mod transform_partition_bucket_scatter;

pub use datablock_splitter::split_partitioned_meta_into_datablocks;
pub use new_aggregate_spiller::LocalPartitionStream;
pub use new_aggregate_spiller::NewAggregateSpillReader;
pub use new_aggregate_spiller::NewAggregateSpiller;
pub use new_aggregate_spiller::PartitionStream;
pub use new_aggregate_spiller::SharedPartitionStream;
pub use new_final_aggregate_state::FinalAggregateSharedState;
pub use new_transform_aggregate_partial::NewTransformPartialAggregate;
Expand Down
Loading
Loading