[WIP][SPARK-54647][PYTHON] Support User-Defined Aggregate Functions (UDAF) #53400
Conversation
allisonwang-db left a comment:
Nice feature!
    def test_udaf_mixed_with_other_agg_not_supported(self):
        """Test that mixing UDAF with other aggregate functions raises error."""
    ...
    class MySum(Aggregator):
Can we add some tests for more complicated data structures, like dictionaries?
added more data types!
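For illustration, an aggregator over a dictionary buffer might look like the sketch below. This is hypothetical (the class name and the value-counting logic are not from the PR); it only assumes the `zero`/`reduce`/`merge`/`finish` interface described in the PR description.

```python
# Hypothetical example: a per-group value counter whose buffer is a dict.
# Assumes the Aggregator interface (zero/reduce/merge/finish) from this PR.
class ValueCounts(Aggregator):
    def zero(self):
        return {}  # empty counts

    def reduce(self, buffer, value):
        # Fold one input value into the dict buffer.
        buffer[value] = buffer.get(value, 0) + 1
        return buffer

    def merge(self, b1, b2):
        # Combine two partial dict buffers.
        for k, v in b2.items():
            b1[k] = b1.get(k, 0) + v
        return b1

    def finish(self, buffer):
        return buffer  # final dict of value -> count
```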
    ]
    ...
    class Aggregator:
do we necessarily need this class?
I see UDTF doesn't need a base class.
>>> class TestUDTF:
... def eval(self, *args: Any):
... yield "hello", "world"
    Apply this UDAF to the given columns.
    ...
    This creates a Column expression that can be used in DataFrame operations.
    The actual aggregation is performed using mapInArrow and applyInArrow.
why not a dedicated physical plan?
    -----
    This implementation uses mapInArrow and applyInArrow internally to perform
    the aggregation. The approach follows:
    1. mapInArrow: Performs partial aggregation (reduce) on each partition
If we want to support partial aggregation with existing arrow UDFs, I think we should use a modified FlatMapGroupsInArrowExec with requiredChildDistribution = UnspecifiedDistribution.
    * MapInArrow, Aggregate, and FlatMapGroupsInArrow operators.
    *
    * This implements a three-phase aggregation pattern:
    * 1. Partial aggregation (MapInArrow): Applies reduce() on each partition, outputs
MapInArrowExec doesn't specify requiredChildOrdering, so where does it sort the data for partial aggregation? I think there should be a ...

The whole approach is based on ...
    group_buffers[grouping_key] = agg.zero()
    ...
    if value is not None:
        group_buffers[grouping_key] = agg.reduce(group_buffers[grouping_key], value)
`group_buffers` holds one aggregation buffer per grouping key within a partition, which will cause memory issues if the key cardinality is large.
A more reasonable physical plan would sort the partition by the key and then emit each group's partial aggregation result as soon as that group finishes.
it mimics HashAggregateExec, while SortAggregateExec is more stable
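A minimal sketch of the sort-based alternative suggested above, assuming the partition is already sorted by grouping key and using the same `zero()`/`reduce()` calls as the snippet under review. Only one group's buffer is live at a time, so memory use no longer grows with key cardinality:

```python
_UNSET = object()  # sentinel: no group started yet

def partial_agg_sorted(rows, agg):
    """Partial aggregation over a partition pre-sorted by grouping key.

    rows: iterator of (grouping_key, value) pairs, sorted by grouping_key.
    agg:  an aggregator exposing zero() and reduce().
    Yields (grouping_key, partial_buffer) as each group finishes.
    """
    current_key, buffer = _UNSET, None
    for key, value in rows:
        if key != current_key:
            if current_key is not _UNSET:
                yield current_key, buffer  # emit the finished group
            current_key, buffer = key, agg.zero()
        if value is not None:
            buffer = agg.reduce(buffer, value)
    if current_key is not _UNSET:
        yield current_key, buffer  # emit the last group
```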
What changes were proposed in this pull request?
Add support for User-Defined Aggregate Functions (UDAF) in PySpark. Currently PySpark supports User-Defined Functions (UDF) and User-Defined Table Functions (UDTF), but lacks support for UDAF. Users need to write custom aggregation logic in Scala/Java or use less efficient workarounds.
This change adds UDAF support using a two-stage aggregation pattern with `mapInArrow` and `applyInArrow`. The basic idea is to implement aggregation (and partial aggregation) in two stages, where `func1` calls `Aggregator.reduce()` for partial aggregation within each partition, and `func2` calls `Aggregator.merge()` to combine partial results, then `Aggregator.finish()` to produce the final results.

Aligned with the Scala side, the implementation provides a Python `Aggregator` base class that users can subclass. Users can create UDAF instances using the `udaf()` function and use them with `DataFrame.agg()` (see the sketch below).

Key changes:
- A new `pyspark.sql.udaf` module with an `Aggregator` base class, a `UserDefinedAggregateFunction` wrapper, and a `udaf()` factory function
- UDAF support in `GroupedData.agg()`, detecting UDAF columns via the `_udaf_func` attribute
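A minimal sketch of the user-facing API described above. The exact signatures (the `udaf()` arguments in particular) are assumptions modeled on the Scala `Aggregator`, not necessarily the final API:

```python
from pyspark.sql.udaf import Aggregator, udaf  # module introduced by this PR

class MySum(Aggregator):
    def zero(self):
        return 0  # initial buffer

    def reduce(self, buffer, value):
        return buffer + value  # partial aggregation within a partition

    def merge(self, b1, b2):
        return b1 + b2  # combine partial results

    def finish(self, buffer):
        return buffer  # final result

# The returnType argument is an assumption for illustration.
my_sum = udaf(MySum(), returnType="long")
df.groupBy("key").agg(my_sum(df.value).alias("total"))
```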
Why are the changes needed?

Currently PySpark lacks support for User-Defined Aggregate Functions (UDAF), which limits users' ability to express complex aggregation logic directly in Python. Users must either write custom aggregation logic in Scala/Java or use less efficient workarounds. This change adds UDAF support to complement the existing UDF and UDTF support in PySpark, aligning with the Scala/Java `Aggregator` interface in `org.apache.spark.sql.expressions.Aggregator`.

Does this PR introduce any user-facing change?
Yes. This PR adds a new feature: User-Defined Aggregate Function (UDAF) support in PySpark. Users can now define custom aggregation logic by subclassing the `Aggregator` class and using the `udaf()` function to create UDAF instances that can be used with `DataFrame.agg()` and `GroupedData.agg()`.

Example:
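A usage sketch, reusing the hypothetical `MySum` and `my_sum` from the sketch above:

```python
# Global aggregation
df.agg(my_sum(df.value)).show()

# Grouped aggregation
df.groupBy("key").agg(my_sum(df.value)).show()
```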
How was this patch tested?
Added comprehensive unit tests in `python/pyspark/sql/tests/test_udaf.py` covering:
- `groupBy().agg()`
- `df.agg()` and `df.groupBy().agg()`
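For instance, a test along these lines (a hedged sketch assuming the hypothetical `MySum`/`udaf()` API above, not the actual test file contents):

```python
def test_udaf_group_by_agg(self):
    df = self.spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])
    my_sum = udaf(MySum(), returnType="long")
    rows = df.groupBy("key").agg(my_sum(df.value).alias("total")).collect()
    self.assertEqual(sorted((r.key, r.total) for r in rows), [("a", 3), ("b", 3)])
```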
Was this patch authored or co-authored using generative AI tooling?

No.