From b0063625aaa28948a835f474177aecbcf4b31e67 Mon Sep 17 00:00:00 2001 From: Qi Zhu <821684824@qq.com> Date: Sat, 8 Nov 2025 17:05:18 +0800 Subject: [PATCH] Change solution to physical optimization --- Cargo.lock | 1 + .../functions-table/src/generate_series.rs | 8 +- datafusion/physical-optimizer/Cargo.toml | 1 + .../src/topk_aggregation.rs | 99 ++++++++++++++++++- 4 files changed, 102 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f500265108ff..add1dfcf0fd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2527,6 +2527,7 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index d71c5945aafc..d40f6ef5257c 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -534,10 +534,12 @@ impl GenerateSeriesFuncImpl { }; } + // Relax nullability to true: the limit-1 case may be optimized into a Max/Min + // aggregate, and the nullability check fails if this field stays non-nullable. let schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Int64, - false, + true, )])); if normalize_args.len() != exprs.len() { @@ -629,7 +631,7 @@ impl GenerateSeriesFuncImpl { let schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), - false, + true, )])); // Check if any argument is null @@ -668,7 +670,7 @@ impl GenerateSeriesFuncImpl { let schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Timestamp(TimeUnit::Nanosecond, None), - false, + true, )])); // Parse start date diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 4df011fc0a05..e52f807787d2 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ 
b/datafusion/physical-optimizer/Cargo.toml @@ -45,6 +45,7 @@ datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-functions-aggregate = {workspace = true} datafusion-physical-plan = { workspace = true } datafusion-pruning = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index b7505f0df4ed..a362135d0e59 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -24,8 +24,12 @@ use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::Column as PhysicalColumn; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort::SortExec; @@ -85,15 +89,101 @@ impl TopKAggregation { Some(Arc::new(new_aggr)) } + fn try_convert_topk_to_minmax( + sort_exec: &SortExec, + ) -> Option> { + let fetch = sort_exec.fetch()?; + if fetch != 1 { + return None; + } + + let sort_exprs = sort_exec.expr(); + if sort_exprs.len() != 1 { + return None; + } + + let sort_expr = &sort_exprs[0]; + let order_desc = sort_expr.options.descending; + let sort_col = sort_expr.expr.as_any().downcast_ref::()?; + + let input = 
sort_exec.input(); + let input_schema = input.schema(); + let col_index = sort_col.index(); + let field = input_schema.field(col_index); + let col_type = field.data_type(); + let col_name = field.name().to_string(); + + match col_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) => {} + _ => return None, + } + + let agg_udf = if order_desc { + datafusion_expr::AggregateUDF::new_from_impl( + datafusion_functions_aggregate::min_max::Max::default(), + ) + } else { + datafusion_expr::AggregateUDF::new_from_impl( + datafusion_functions_aggregate::min_max::Min::default(), + ) + }; + + let phys_col: Arc = + Arc::new(PhysicalColumn::new(&col_name, col_index)); + + let agg_fn_expr = AggregateExprBuilder::new(Arc::new(agg_udf), vec![phys_col]) + .schema(Arc::clone(&input_schema)) + .alias(&col_name) + .build() + .ok()?; + + let agg_physical: Arc = + Arc::new(agg_fn_expr); + + let agg = AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![agg_physical.clone()], + vec![None], + Arc::clone(input), + input_schema.clone(), + ) + .ok()?; + + Some(Arc::new(agg)) + } + fn transform_sort(plan: &Arc) -> Option> { let sort = plan.as_any().downcast_ref::()?; + // Try TopK(fetch=1) to MIN/MAX optimization first + if let Some(optimized) = Self::try_convert_topk_to_minmax(sort) { + return Some(optimized); + } + let children = sort.children(); let child = children.into_iter().exactly_one().ok()?; let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; let order_desc = order.options.descending; - let order = order.expr.as_any().downcast_ref::()?; + let order = 
order.expr.as_any().downcast_ref::()?; let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; @@ -111,7 +201,8 @@ impl TopKAggregation { } else if let Some(proj) = plan.as_any().downcast_ref::() { // track renames due to successive projections for proj_expr in proj.expr() { - let Some(src_col) = proj_expr.expr.as_any().downcast_ref::() + let Some(src_col) = + proj_expr.expr.as_any().downcast_ref::() else { continue; };