
Conversation

@Yicong-Huang (Contributor) commented Dec 9, 2025

What changes were proposed in this pull request?

Add support for User-Defined Aggregate Functions (UDAF) in PySpark. Currently PySpark supports User-Defined Functions (UDF) and User-Defined Table Functions (UDTF), but lacks support for UDAF. Users need to write custom aggregation logic in Scala/Java or use less efficient workarounds.

This change adds UDAF support using a two-stage aggregation pattern with mapInArrow and applyInArrow. The basic idea is to implement aggregation (and partial aggregation) by:

df.selectExpr("rand() as key").mapInArrow(reduce).groupBy("key").applyInArrow(merge)

Here the function passed to mapInArrow calls Aggregator.reduce() for partial aggregation within each partition, and the function passed to applyInArrow calls Aggregator.merge() to combine the partial results, then Aggregator.finish() to produce the final result.
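A minimal sketch of the pattern, assuming a DataFrame df with a bigint value column; reduce_partial, merge_groups, and the output schemas are illustrative names for this sketch, not the actual patch internals:

import pyarrow as pa
import pyarrow.compute as pc

def reduce_partial(batches):
    # Stage 1 (mapInArrow): fold every row of this partition into one
    # partial sum, so only one row per partition crosses the shuffle.
    partial = 0
    for batch in batches:
        partial += pc.sum(batch.column("value")).as_py() or 0
    yield pa.RecordBatch.from_pydict({"key": [0], "partial": [partial]})

def merge_groups(table):
    # Stage 2 (applyInArrow): merge the partial sums into the final value.
    merged = pc.sum(table.column("partial")).as_py() or 0
    return pa.Table.from_pydict({"result": [merged]})

partials = df.mapInArrow(reduce_partial, "key int, partial bigint")
result = partials.groupBy("key").applyInArrow(merge_groups, "result bigint")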

Aligned with the Scala side, the implementation provides a Python Aggregator base class that users can subclass:

from typing import Any

class Aggregator:
    def zero(self) -> Any:
        """Return the zero value for the aggregation buffer"""
        raise NotImplementedError

    def reduce(self, buffer: Any, value: Any) -> Any:
        """Combine an input value into the buffer"""
        raise NotImplementedError

    def merge(self, buffer1: Any, buffer2: Any) -> Any:
        """Merge two intermediate buffers"""
        raise NotImplementedError

    def finish(self, reduction: Any) -> Any:
        """Produce the final result from the buffer"""
        raise NotImplementedError
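For instance, an aggregator whose buffer differs from its final output, such as an average, exercises all four methods. This MyAverage class is an illustrative sketch, not part of the patch:

class MyAverage(Aggregator):
    def zero(self):
        return (0, 0)  # (running sum, running count)

    def reduce(self, buffer, value):
        s, n = buffer
        return (s + value, n + 1)

    def merge(self, buffer1, buffer2):
        return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1])

    def finish(self, reduction):
        s, n = reduction
        return s / n if n else None  # average, or NULL for an empty group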

Users can create UDAF instances using the udaf() function and use them with DataFrame.agg():

sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))
df.groupBy("group").agg(sum_udaf(df.value))

Key changes:

  • Added pyspark.sql.udaf module with Aggregator base class, UserDefinedAggregateFunction wrapper, and udaf() factory function
  • Integrated UDAF support in GroupedData.agg() by detecting UDAF columns via the _udaf_func attribute (sketched below)
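A rough sketch of what that detection could look like; the method body and helper names here are assumptions for illustration, not the actual implementation:

def agg(self, *exprs):
    # If any column carries a _udaf_func attribute, route the whole
    # aggregation through the mapInArrow/applyInArrow rewrite above.
    if any(hasattr(e, "_udaf_func") for e in exprs):
        return self._aggregate_with_udaf(exprs)  # hypothetical helper
    return self._regular_agg(*exprs)             # existing code path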

Why are the changes needed?

Currently PySpark lacks support for User-Defined Aggregate Functions (UDAF), which limits users' ability to express complex aggregation logic directly in Python. Users must either write custom aggregation logic in Scala/Java or use less efficient workarounds. This change adds UDAF support to complement existing UDF and UDTF support in PySpark, aligning with the Scala/Java Aggregator interface in org.apache.spark.sql.expressions.Aggregator.

Does this PR introduce any user-facing change?

Yes. This PR adds a new feature: User-Defined Aggregate Functions (UDAF) in PySpark. Users can now define custom aggregation logic by subclassing the Aggregator class and using the udaf() function to create UDAF instances that can be used with DataFrame.agg() and GroupedData.agg().

Example:

from pyspark.sql.udaf import Aggregator, udaf

class MySum(Aggregator):
    def zero(self):
        return 0
    def reduce(self, buffer, value):
        return buffer + value
    def merge(self, buffer1, buffer2):
        return buffer1 + buffer2
    def finish(self, reduction):
        return reduction

sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))

How was this patch tested?

Added comprehensive unit tests in python/pyspark/sql/tests/test_udaf.py covering:

  • Basic aggregation (sum, average, max)
  • Grouped aggregation with groupBy().agg()
  • Null value handling
  • Empty DataFrame handling
  • Large datasets (20000+ rows) distributed across partitions
  • Error handling for invalid inputs
  • Integration with df.agg() and df.groupBy().agg()

Was this patch authored or co-authored using generative AI tooling?

No.

@allisonwang-db (Contributor) left a comment:
Nice feature!

    def test_udaf_mixed_with_other_agg_not_supported(self):
        """Test that mixing UDAF with other aggregate functions raises error."""

        class MySum(Aggregator):
Contributor:
Can we add some tests for more complicated data structures, like dictionaries?

Contributor (author):
added more data types!

@Yicong-Huang Yicong-Huang changed the title [SPARK-54647][PYTHON] Support User-Defined Aggregate Functions (UDAF) [WIP][SPARK-54647][PYTHON] Support User-Defined Aggregate Functions (UDAF) Dec 10, 2025
@Yicong-Huang Yicong-Huang marked this pull request as draft December 10, 2025 01:36
class Aggregator:
Contributor:
do we necessarily need this class?
I see UDTF doesn't need a base class.

    >>> class TestUDTF:
    ...     def eval(self, *args: Any):
    ...         yield "hello", "world"

Apply this UDAF to the given columns.

This creates a Column expression that can be used in DataFrame operations.
The actual aggregation is performed using mapInArrow and applyInArrow.
Contributor:
Why not a dedicated physical plan?

-----
This implementation uses mapInArrow and applyInArrow internally to perform
the aggregation. The approach follows:
1. mapInArrow: Performs partial aggregation (reduce) on each partition
@zhengruifeng (Contributor) commented Dec 11, 2025:

If we want to support partial aggregation with existing arrow UDFs, I think we should use a modified FlatMapGroupsInArrowExec with requiredChildDistribution = UnspecifiedDistribution.

* MapInArrow, Aggregate, and FlatMapGroupsInArrow operators.
*
* This implements a three-phase aggregation pattern:
* 1. Partial aggregation (MapInArrow): Applies reduce() on each partition, outputs
Contributor:

MapInArrowExec doesn't set requiredChildOrdering, so where does it sort the data for partial aggregation?

@zhengruifeng (Contributor):

The basic idea is to implement aggregation (and partial aggregation) by:
df.selectExpr("rand() as key").mapInArrow(reduce).groupBy("key").applyInArrow(merge)

I think there should be a sortWithinPartitions before mapInArrow for partial aggregation.
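Something like the following, where reduce_partial is the same hypothetical partition-level reducer sketched above and "key" is an assumed grouping column:

# Hedged sketch of the suggested ordering: sort each partition by key so
# the reducer can stream one group at a time instead of buffering all keys.
partials = (
    df.sortWithinPartitions("key")
      .mapInArrow(reduce_partial, "key int, partial bigint")
)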

@zhengruifeng (Contributor):

The whole approach is based on mapInArrow and applyInArrow; how does it support function registration so that it can be used in SQL?

group_buffers[grouping_key] = agg.zero()

if value is not None:
    group_buffers[grouping_key] = agg.reduce(group_buffers[grouping_key], value)
Contributor:

group_buffers holds the aggregation buffers for every key within a partition, so it will cause memory issues if the key cardinality is large.
A reasonable physical plan should sort the partition by the key and then output each group's partial aggregation result as soon as that group finishes.
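A hedged sketch of that sort-based alternative, assuming rows arrive pre-sorted by key and agg is the user's Aggregator instance; all names and the schema are illustrative assumptions:

import pyarrow as pa

def reduce_sorted(batches):
    # With the partition sorted by key, one buffer at a time suffices:
    # each group's partial result is emitted as soon as the group ends.
    sentinel = object()  # distinguishes "no group yet" from a NULL key
    current_key, buffer = sentinel, None
    for batch in batches:
        for row in batch.to_pylist():
            if row["key"] != current_key:
                if current_key is not sentinel:
                    yield pa.RecordBatch.from_pydict(
                        {"key": [current_key], "partial": [buffer]})
                current_key, buffer = row["key"], agg.zero()
            if row["value"] is not None:
                buffer = agg.reduce(buffer, row["value"])
    if current_key is not sentinel:
        yield pa.RecordBatch.from_pydict({"key": [current_key], "partial": [buffer]})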

Contributor:

This mimics HashAggregateExec, whereas SortAggregateExec is more stable.
