Skip to content

Commit 2151741

Browse files
bubulalabujizezhang
authored andcommitted
feat: support named arguments for aggregate and window udfs (apache#18389)
## Which issue does this PR close? Addresses portions of apache#17379. ## Rationale for this change Add support for aggregate and window UDFs in the same way as we did it for scalar UDFs here: apache#18019 ## Are these changes tested? Yes ## Are there any user-facing changes? Yes, the changes are user-facing, documented, purely additive and non-breaking.
1 parent 312d5c8 commit 2151741

File tree

6 files changed

+210
-42
lines changed

6 files changed

+210
-42
lines changed

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ impl Correlation {
8888
signature: Signature::exact(
8989
vec![DataType::Float64, DataType::Float64],
9090
Volatility::Immutable,
91-
),
91+
)
92+
.with_parameter_names(vec!["y".to_string(), "x".to_string()])
93+
.expect("valid parameter names for corr"),
9294
}
9395
}
9496
}

datafusion/functions-aggregate/src/percentile_cont.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ impl PercentileCont {
146146
variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
147147
}
148148
Self {
149-
signature: Signature::one_of(variants, Volatility::Immutable),
149+
signature: Signature::one_of(variants, Volatility::Immutable)
150+
.with_parameter_names(vec!["expr".to_string(), "percentile".to_string()])
151+
.expect("valid parameter names for percentile_cont"),
150152
aliases: vec![String::from("quantile_cont")],
151153
}
152154
}

datafusion/functions-window/src/lead_lag.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,13 @@ impl WindowShift {
137137
TypeSignature::Any(3),
138138
],
139139
Volatility::Immutable,
140-
),
140+
)
141+
.with_parameter_names(vec![
142+
"expr".to_string(),
143+
"offset".to_string(),
144+
"default".to_string(),
145+
])
146+
.expect("valid parameter names for lead/lag"),
141147
kind,
142148
}
143149
}

datafusion/sql/src/expr/function.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,30 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
386386
};
387387

388388
if let Ok(fun) = self.find_window_func(&name) {
389-
let args = self.function_args_to_expr(args, schema, planner_context)?;
389+
let (args, arg_names) =
390+
self.function_args_to_expr_with_names(args, schema, planner_context)?;
391+
392+
let resolved_args = if arg_names.iter().any(|name| name.is_some()) {
393+
let signature = match &fun {
394+
WindowFunctionDefinition::AggregateUDF(udaf) => udaf.signature(),
395+
WindowFunctionDefinition::WindowUDF(udwf) => udwf.signature(),
396+
};
397+
398+
if let Some(param_names) = &signature.parameter_names {
399+
datafusion_expr::arguments::resolve_function_arguments(
400+
param_names,
401+
args,
402+
arg_names,
403+
)?
404+
} else {
405+
return plan_err!(
406+
"Window function '{}' does not support named arguments",
407+
name
408+
);
409+
}
410+
} else {
411+
args
412+
};
390413

391414
// Plan FILTER clause if present
392415
let filter = filter
@@ -396,7 +419,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
396419

397420
let mut window_expr = RawWindowExpr {
398421
func_def: fun,
399-
args,
422+
args: resolved_args,
400423
partition_by,
401424
order_by,
402425
window_frame,
@@ -464,8 +487,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
464487
);
465488
}
466489

467-
let mut args =
468-
self.function_args_to_expr(args, schema, planner_context)?;
490+
let (mut args, mut arg_names) =
491+
self.function_args_to_expr_with_names(args, schema, planner_context)?;
469492

470493
let order_by = if fm.supports_within_group_clause() {
471494
let within_group = self.order_by_to_sort_expr(
@@ -479,6 +502,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
479502
// Add the WITHIN GROUP ordering expressions to the front of the argument list
480503
// So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg)
481504
if !within_group.is_empty() {
505+
// Prepend None arg names for each WITHIN GROUP expression
506+
let within_group_count = within_group.len();
507+
arg_names = std::iter::repeat_n(None, within_group_count)
508+
.chain(arg_names)
509+
.collect();
510+
482511
args = within_group
483512
.iter()
484513
.map(|sort| sort.expr.clone())
@@ -506,9 +535,26 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
506535
.transpose()?
507536
.map(Box::new);
508537

538+
let resolved_args = if arg_names.iter().any(|name| name.is_some()) {
539+
if let Some(param_names) = &fm.signature().parameter_names {
540+
datafusion_expr::arguments::resolve_function_arguments(
541+
param_names,
542+
args,
543+
arg_names,
544+
)?
545+
} else {
546+
return plan_err!(
547+
"Aggregate function '{}' does not support named arguments",
548+
fm.name()
549+
);
550+
}
551+
} else {
552+
args
553+
};
554+
509555
let mut aggregate_expr = RawAggregateExpr {
510556
func: fm,
511-
args,
557+
args: resolved_args,
512558
distinct,
513559
filter,
514560
order_by,

datafusion/sqllogictest/test_files/named_arguments.slt

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,135 @@ SELECT substr(str => 'hello world', start_pos => 7, length => 5);
137137
# Reset to default dialect
138138
statement ok
139139
set datafusion.sql_parser.dialect = 'Generic';
140+
141+
#############
142+
## Aggregate UDF Tests - using corr(y, x) function
143+
#############
144+
145+
# Setup test data
146+
statement ok
147+
CREATE TABLE correlation_test(col1 DOUBLE, col2 DOUBLE) AS VALUES
148+
(1.0, 2.0),
149+
(2.0, 4.0),
150+
(3.0, 6.0),
151+
(4.0, 8.0);
152+
153+
# Test positional arguments (baseline)
154+
query R
155+
SELECT corr(col1, col2) FROM correlation_test;
156+
----
157+
1
158+
159+
# Test named arguments out of order (proves named args work for aggregates)
160+
query R
161+
SELECT corr(x => col2, y => col1) FROM correlation_test;
162+
----
163+
1
164+
165+
# Error: function doesn't support named arguments (count has no parameter names)
166+
query error DataFusion error: Error during planning: Aggregate function 'count' does not support named arguments
167+
SELECT count(value => col1) FROM correlation_test;
168+
169+
# Cleanup
170+
statement ok
171+
DROP TABLE correlation_test;
172+
173+
#############
174+
## Aggregate UDF with WITHIN GROUP Tests - using percentile_cont(expression, percentile)
175+
## This tests the special handling where WITHIN GROUP ORDER BY expressions are prepended to args
176+
#############
177+
178+
# Setup test data
179+
statement ok
180+
CREATE TABLE percentile_test(salary DOUBLE) AS VALUES
181+
(50000.0),
182+
(60000.0),
183+
(70000.0),
184+
(80000.0),
185+
(90000.0);
186+
187+
# Test positional arguments (baseline) - standard call without WITHIN GROUP
188+
query R
189+
SELECT percentile_cont(salary, 0.5) FROM percentile_test;
190+
----
191+
70000
192+
193+
# Test WITHIN GROUP with positional argument
194+
query R
195+
SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
196+
----
197+
70000
198+
199+
# Test WITHIN GROUP with named argument for percentile
200+
# The ORDER BY expression (salary) is prepended internally, becoming: percentile_cont(salary, 0.5)
201+
# We use named argument for percentile, which should work correctly
202+
query R
203+
SELECT percentile_cont(percentile => 0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
204+
----
205+
70000
206+
207+
# Verify the WITHIN GROUP prepending logic with different percentile value
208+
query R
209+
SELECT percentile_cont(percentile => 0.25) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
210+
----
211+
60000
212+
213+
# Cleanup
214+
statement ok
215+
DROP TABLE percentile_test;
216+
217+
#############
218+
## Window UDF Tests - using lead(expression, offset, default) function
219+
#############
220+
221+
# Setup test data
222+
statement ok
223+
CREATE TABLE window_test(id INT, value INT) AS VALUES
224+
(1, 10),
225+
(2, 20),
226+
(3, 30),
227+
(4, 40);
228+
229+
# Test positional arguments (baseline)
230+
query II
231+
SELECT id, lead(value, 1, 0) OVER (ORDER BY id) FROM window_test ORDER BY id;
232+
----
233+
1 20
234+
2 30
235+
3 40
236+
4 0
237+
238+
# Test named arguments out of order (proves named args work for window functions)
239+
query II
240+
SELECT id, lead(default => 0, offset => 1, expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id;
241+
----
242+
1 20
243+
2 30
244+
3 40
245+
4 0
246+
247+
# Test with 1 argument (offset and default use defaults)
248+
query II
249+
SELECT id, lead(expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id;
250+
----
251+
1 20
252+
2 30
253+
3 40
254+
4 NULL
255+
256+
# Test with 2 arguments (default uses default)
257+
query II
258+
SELECT id, lead(expr => value, offset => 2) OVER (ORDER BY id) FROM window_test ORDER BY id;
259+
----
260+
1 30
261+
2 40
262+
3 NULL
263+
4 NULL
264+
265+
# Error: function doesn't support named arguments (row_number has no parameter names)
266+
query error DataFusion error: Error during planning: Window function 'row_number' does not support named arguments
267+
SELECT row_number(value => 1) OVER (ORDER BY id) FROM window_test;
268+
269+
# Cleanup
270+
statement ok
271+
DROP TABLE window_test;

docs/source/library-user-guide/functions/adding-udfs.md

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -588,10 +588,17 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap
588588

589589
## Named Arguments
590590

591-
DataFusion supports PostgreSQL-style named arguments for scalar functions, allowing you to pass arguments by parameter name:
591+
DataFusion supports named arguments for Scalar, Window, and Aggregate UDFs, allowing you to pass arguments by parameter name:
592592

593593
```sql
594+
-- Scalar function
594595
SELECT substr(str => 'hello', start_pos => 2, length => 3);
596+
597+
-- Window function
598+
SELECT lead(expr => value, offset => 1) OVER (ORDER BY id) FROM table;
599+
600+
-- Aggregate function
601+
SELECT corr(y => col1, x => col2) FROM table;
595602
```
596603

597604
Named arguments can be mixed with positional arguments, but positional arguments must come first:
@@ -602,38 +609,7 @@ SELECT substr('hello', start_pos => 2, length => 3); -- Valid
602609

603610
### Implementing Functions with Named Arguments
604611

605-
To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`:
606-
607-
```rust
608-
# use arrow::datatypes::DataType;
609-
# use datafusion_expr::{Signature, Volatility};
610-
#
611-
# #[derive(Debug)]
612-
# struct MyFunction {
613-
# signature: Signature,
614-
# }
615-
#
616-
impl MyFunction {
617-
fn new() -> Self {
618-
Self {
619-
signature: Signature::uniform(
620-
2,
621-
vec![DataType::Float64],
622-
Volatility::Immutable
623-
)
624-
.with_parameter_names(vec![
625-
"base".to_string(),
626-
"exponent".to_string()
627-
])
628-
.expect("valid parameter names"),
629-
}
630-
}
631-
}
632-
```
633-
634-
The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function.
635-
636-
### Example
612+
To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`. This works the same way for Scalar, Window, and Aggregate UDFs:
637613

638614
```rust
639615
# use std::sync::Arc;
@@ -681,10 +657,14 @@ impl ScalarUDFImpl for PowerFunction {
681657
}
682658
```
683659

684-
Once registered, users can call your function with named arguments:
660+
The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function.
661+
662+
Once registered, users can call your functions with named arguments in any order:
685663

686664
```sql
665+
-- All equivalent
687666
SELECT power(base => 2.0, exponent => 3.0);
667+
SELECT power(exponent => 3.0, base => 2.0);
688668
SELECT power(2.0, exponent => 3.0);
689669
```
690670

0 commit comments

Comments
 (0)