5 changes: 4 additions & 1 deletion python/pyspark/sql/functions/__init__.py
```diff
@@ -534,7 +534,8 @@
     "user",
     # "uuid": Excluded because of the name conflict with builtin uuid module
     "version",
-    # UDF, UDTF and UDT
+    # UDF, UDAF, UDTF and UDT
+    "Aggregator",
     "AnalyzeArgument",
     "AnalyzeResult",
     "ArrowUDFType",
@@ -544,6 +545,7 @@
     "SelectedColumn",
     "SkipRestOfInputTableException",
     "UserDefinedFunction",
+    "UserDefinedAggregateFunction",
     "UserDefinedTableFunction",
     "arrow_udf",
     # Geospatial ST Functions
@@ -556,6 +558,7 @@
     "call_udf",
     "pandas_udf",
     "udf",
+    "udaf",
     "udtf",
     "arrow_udtf",
     "unwrap_udt",
```
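For orientation, here is a minimal sketch of how the newly exported names might fit together. It assumes the Python `Aggregator` mirrors the Scala `Aggregator` contract (`zero`/`reduce`/`merge`/`finish`) and that `udaf` wraps an aggregator instance together with a return type; neither signature appears in this diff, so treat both as assumptions.

```python
# Sketch only: assumes Aggregator exposes zero/reduce/merge/finish like its
# Scala counterpart, and that udaf(aggregator, returnType) returns a callable
# column wrapper. Neither signature is confirmed by this diff.
from pyspark.sql.functions import Aggregator, udaf

class SumOfSquares(Aggregator):
    def zero(self):
        return 0.0                      # initial aggregation buffer

    def reduce(self, buffer, value):
        return buffer + value * value   # fold one input row into the buffer

    def merge(self, b1, b2):
        return b1 + b2                  # combine partial buffers

    def finish(self, buffer):
        return buffer                   # produce the final result

sum_sq = udaf(SumOfSquares(), "double")  # hypothetical signature
```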
1 change: 1 addition & 0 deletions python/pyspark/sql/functions/builtin.py
```diff
@@ -54,6 +54,7 @@
 
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401
+from pyspark.sql.udaf import Aggregator, UserDefinedAggregateFunction, udaf  # noqa: F401
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
 from pyspark.sql.udtf import OrderingColumn, PartitioningColumn, SelectedColumn  # noqa: F401
 from pyspark.sql.udtf import SkipRestOfInputTableException  # noqa: F401
```
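This re-export keeps `pyspark.sql.functions` working as a one-stop import surface, the same treatment SPARK-22409 gave `UserDefinedFunction`. Assuming the re-export lands as shown, both import paths should resolve to the same objects:

```python
# Both paths name the same objects if the re-export above is in place;
# the pyspark.sql.udaf module itself is not shown in this diff.
from pyspark.sql.functions import udaf, Aggregator, UserDefinedAggregateFunction
from pyspark.sql.udaf import udaf as udaf_direct

assert udaf is udaf_direct
```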
11 changes: 11 additions & 0 deletions python/pyspark/sql/group.py
```diff
@@ -187,6 +187,17 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
         # Columns
         assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
         exprs = cast(Tuple[Column, ...], exprs)
+
+        # Check if any column is a UDAF column (has _udaf_func attribute)
+        from pyspark.sql.udaf import _handle_udaf_aggregation_in_grouped_data
+
+        udaf_cols = [c for c in exprs if hasattr(c, "_udaf_func")]
+        if udaf_cols:
+            return _handle_udaf_aggregation_in_grouped_data(
+                self._df, self._jgd, exprs, udaf_cols
+            )
+
+        # Normal column aggregation
         jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.session._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.session)
```
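To illustrate the dispatch added to `GroupedData.agg`: a `Column` produced by a `udaf`-wrapped aggregator is evidently expected to carry a `_udaf_func` attribute, and the presence of any such column routes the whole expression list to `_handle_udaf_aggregation_in_grouped_data` instead of the plain JVM `agg` call. A hedged end-to-end sketch, reusing the hypothetical `SumOfSquares`/`sum_sq` from the earlier example:

```python
# End-to-end sketch; SumOfSquares and the udaf(...) signature are assumptions
# carried over from the earlier example, not part of this diff.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1.0), ("a", 2.0), ("b", 3.0)], ["key", "x"])

# sum_sq(df.x) yields a Column carrying _udaf_func, so GroupedData.agg
# takes the UDAF branch added in group.py above.
df.groupBy("key").agg(sum_sq(df.x)).show()
```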