1 change: 1 addition & 0 deletions doc/source/data/api/aggregate.rst
@@ -25,6 +25,7 @@ compute aggregations.
AbsMax
Quantile
Unique
AsList
CountDistinct
ValueCounter
MissingValuePercentage
@@ -44,6 +44,7 @@ def _pa_is_in(left: Any, right: Any) -> Any:
Operation.SUB: operator.sub,
Operation.MUL: operator.mul,
Operation.DIV: operator.truediv,
Operation.MOD: operator.mod,
Operation.FLOORDIV: operator.floordiv,
Operation.GT: operator.gt,
Operation.LT: operator.lt,
@@ -128,6 +129,11 @@ def _pa_add_or_concat(left: Any, right: Any) -> Any:
Operation.SUB: pc.subtract,
Operation.MUL: pc.multiply,
Operation.DIV: pc.divide,
Operation.MOD: lambda left, right: (
# Modulo op is essentially:
# r = N - floor(N/M) * M
pc.subtract(left, pc.multiply(pc.floor(pc.divide(left, right)), right))
),
Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)),
Operation.GT: pc.greater,
Operation.LT: pc.less,
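Since the PyArrow backend above synthesizes modulo from divide/floor/multiply/subtract rather than calling a dedicated kernel, here is a quick standalone check (not part of the PR; assumes only that pyarrow is installed) that the identity r = N - floor(N/M) * M reproduces Python's floored % semantics, including for negative operands:

import pyarrow as pa
import pyarrow.compute as pc

n, m = pa.scalar(-7.0), pa.scalar(3.0)
# r = N - floor(N/M) * M, exactly as in the Operation.MOD lambda above
r = pc.subtract(n, pc.multiply(pc.floor(pc.divide(n, m)), m))
assert r.as_py() == (-7 % 3) == 2.0  # Python's % floors too: -7 % 3 == 2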
@@ -24,6 +24,7 @@
Operation.SUB: "-",
Operation.MUL: "*",
Operation.DIV: "/",
Operation.MOD: "%",
Operation.FLOORDIV: "//",
Operation.GT: ">",
Operation.LT: "<",
15 changes: 6 additions & 9 deletions python/ray/data/_internal/util.py
@@ -1716,15 +1716,12 @@ def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool:
if len(actual) == 0:
return True

try:
pd.testing.assert_frame_equal(
_sort_df(actual).reset_index(drop=True),
_sort_df(expected).reset_index(drop=True),
check_dtype=False,
)
return True
except AssertionError:
return False
pd.testing.assert_frame_equal(
_sort_df(actual).reset_index(drop=True),
_sort_df(expected).reset_index(drop=True),
check_dtype=False,
)
return True


def merge_resources_to_ray_remote_args(
85 changes: 85 additions & 0 deletions python/ray/data/aggregate.py
@@ -174,6 +174,28 @@ class AggregateFnV2(AggregateFn, abc.ABC, Generic[AccumulatorType, AggOutputType
4. **Finalization**: Optionally, the `finalize` method transforms the
final combined accumulator into the desired output format.

Generic Type Parameters:
This class is parameterized by two type variables:

- ``AccumulatorType``: The type of the intermediate state (accumulator) used
during aggregation. This is what `aggregate_block` returns, what `combine`
takes as inputs and returns, and what `finalize` receives. For simple
aggregations like `Sum`, this might just be a numeric type. For more complex
aggregations like `Mean`, this could be a composite type like
``List[Union[int, float]]`` representing ``[sum, count]``.

- ``AggOutputType``: The type of the final result after `finalize` is called.
This is what gets written to the output dataset. For `Sum`, this is the
same as the accumulator type (a number). For `Mean`, the accumulator is
``[sum, count]`` but the output is a single ``float`` (the computed mean).

Examples of type parameterization in built-in aggregations::

Count(AggregateFnV2[int, int]) # accumulator: int, output: int
Sum(AggregateFnV2[Union[int, float], ...]) # accumulator: number, output: number
Mean(AggregateFnV2[List[...], float]) # accumulator: [sum, count], output: float
Std(AggregateFnV2[List[...], float]) # accumulator: [M2, mean, count], output: float

Args:
name: The name of the aggregation. This will be used as the column name
in the output, e.g., "sum(my_col)".
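To make the two type parameters concrete, here is a hedged sketch (not part of this PR) of a custom Product aggregation in which accumulator and output share a type, so no finalize override is needed. The constructor arguments and method signatures mirror the AsList implementation added below; the exact import paths for Block and BlockColumnAccessor are assumptions based on this diff:

from typing import Optional, Union

from ray.data.aggregate import AggregateFnV2
from ray.data.block import Block, BlockColumnAccessor


class Product(AggregateFnV2[Union[int, float], Union[int, float]]):
    """Multiplies all values of the target column within each group."""

    def __init__(self, on: str, alias_name: Optional[str] = None):
        super().__init__(
            alias_name if alias_name else f"product({on})",
            on=on,
            ignore_nulls=True,
            zero_factory=lambda: 1,  # multiplicative identity as the empty accumulator
        )

    def aggregate_block(self, block: Block) -> Union[int, float]:
        # Fold one block's worth of values into a partial product.
        accessor = BlockColumnAccessor.for_column(block[self.get_target_column()])
        if self._ignore_nulls:
            accessor = BlockColumnAccessor.for_column(accessor.dropna())
        result = 1
        for value in accessor.to_pylist():
            result *= value
        return result

    def combine(
        self, current_accumulator: Union[int, float], new: Union[int, float]
    ) -> Union[int, float]:
        # Merge partial products from different blocks/shards.
        return current_accumulator * new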
@@ -375,6 +397,69 @@ def combine(self, current_accumulator: int, new: int) -> int:
return current_accumulator + new


@PublicAPI
class AsList(AggregateFnV2[List, List]):
"""Listing aggregation combining all values within the group into a single
list element.

Example:

.. testcode::
:skipif: True

# Skip testing b/c this example require proper ordering of the output
# to be robust and not flaky

import ray
from ray.data.aggregate import AsList

ds = ray.data.range(10)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}

# Listing all elements per group:
result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
[Review thread on the line above]
Member: Let's add sort for testing?
Suggested change:
- result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
+ result = ds.groupby("group_key").aggregate(AsList(on="id")).sort("group_key").take_all()
Contributor Author: That's not sufficient -- we need to order the list too, which makes this code example really clumsy (hence why I'm skipping testing it).
Member: Enabling preserve_order should solve this.
Contributor Author (alexeykudinkin, Jan 7, 2026): preserve_order won't guarantee sequence order either b/c it depends on the order of arrival of the shards into HashShuffleAggregator.
# result: [{'group_key': 0, 'list(id)': [0, 3, 6, 9]},
# {'group_key': 1, 'list(id)': [1, 4, 7]},
# {'group_key': 2, 'list(id)': [2, 5, 8]}]

Args:
on: The name of the column to collect values from. Must be provided.
alias_name: Optional name for the resulting column.
ignore_nulls: Whether to ignore null values when collecting. If `True`,
nulls are skipped. If `False` (default), nulls are included in the list.
"""

def __init__(
self,
on: str,
alias_name: Optional[str] = None,
ignore_nulls: bool = False,
):
super().__init__(
alias_name if alias_name else f"list({on or ''})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: [],
)

def aggregate_block(self, block: Block) -> AccumulatorType:
column_accessor = BlockColumnAccessor.for_column(
block[self.get_target_column()]
)

if self._ignore_nulls:
column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna())

return column_accessor.to_pylist()

def combine(
self, current_accumulator: AccumulatorType, new: AccumulatorType
) -> AccumulatorType:
return current_accumulator + new


@PublicAPI
class Sum(AggregateFnV2[Union[int, float], Union[int, float]]):
"""Defines sum aggregation.
14 changes: 13 additions & 1 deletion python/ray/data/expressions.py
@@ -69,6 +69,7 @@ class Operation(Enum):
SUB = "sub"
MUL = "mul"
DIV = "div"
MOD = "mod"
FLOORDIV = "floordiv"
GT = "gt"
LT = "lt"
@@ -299,7 +300,10 @@ def _bin(self, other: Any, op: Operation) -> "Expr":
other = LiteralExpr(other)
return BinaryExpr(op, self, other)

# arithmetic
#
# Arithmetic ops
#

def __add__(self, other: Any) -> "Expr":
"""Addition operator (+)."""
return self._bin(other, Operation.ADD)
@@ -324,6 +328,14 @@ def __rmul__(self, other: Any) -> "Expr":
"""Reverse multiplication operator (for literal * expr)."""
return LiteralExpr(other)._bin(self, Operation.MUL)

def __mod__(self, other: Any) -> "Expr":
"""Modulo operator (%)."""
return self._bin(other, Operation.MOD)

def __rmod__(self, other: Any) -> "Expr":
"""Reverse modulo operator (for literal % expr)."""
return LiteralExpr(other)._bin(self, Operation.MOD)

def __truediv__(self, other: Any) -> "Expr":
"""Division operator (/)."""
return self._bin(other, Operation.DIV)
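For reference, a minimal usage sketch of the new operator overloads (assumes a Ray build that includes this change; the with_column usage follows the e2e test below). The +1 offset is only there to avoid a zero divisor in the __rmod__ case:

import ray
from ray.data.expressions import col

ds = ray.data.range(10)
# col % literal dispatches to Expr.__mod__
ds = ds.with_column("bucket", col("id") % 3)
# literal % col dispatches to Expr.__rmod__
ds = ds.with_column("rmod", 10 % (col("id") + 1))
# Expected first rows: {'id': 0, 'bucket': 0, 'rmod': 0}, {'id': 1, 'bucket': 1, 'rmod': 0}
print(ds.take(2))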
76 changes: 76 additions & 0 deletions python/ray/data/tests/test_groupby_e2e.py
@@ -20,6 +20,7 @@
from ray.data.aggregate import (
AbsMax,
AggregateFn,
AsList,
Count,
CountDistinct,
Max,
@@ -32,6 +33,7 @@
)
from ray.data.block import BlockAccessor
from ray.data.context import DataContext, ShuffleStrategy
from ray.data.expressions import col
from ray.data.tests.conftest import * # noqa
from ray.data.tests.util import named_values
from ray.tests.conftest import * # noqa
@@ -468,6 +470,80 @@ def _to_batch_format(ds):
)


@pytest.mark.parametrize("num_parts", [1, 10])
@pytest.mark.parametrize("batch_format", ["pandas", "pyarrow"])
def test_as_list_e2e(
ray_start_regular_shared_2_cpus,
batch_format,
num_parts,
disable_fallback_to_object_extension,
):
ds = ray.data.range(10)
ds = ds.with_column("group_key", col("id") % 3).repartition(num_parts)

# Listing all elements per group:
result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()

for i in range(len(result)):
result[i]["list(id)"] = sorted(result[i]["list(id)"])

assert sorted(result, key=lambda x: x["group_key"]) == [
{"group_key": 0, "list(id)": [0, 3, 6, 9]},
{"group_key": 1, "list(id)": [1, 4, 7]},
{"group_key": 2, "list(id)": [2, 5, 8]},
]


@pytest.mark.parametrize("num_parts", [1, 10])
@pytest.mark.parametrize("batch_format", ["pandas", "pyarrow"])
def test_as_list_with_nulls(
ray_start_regular_shared_2_cpus,
batch_format,
num_parts,
disable_fallback_to_object_extension,
):
# Test with nulls included (default behavior: ignore_nulls=False)
ds = ray.data.from_items(
[
{"group": "A", "value": 1},
{"group": "A", "value": None},
{"group": "A", "value": 3},
{"group": "B", "value": None},
{"group": "B", "value": 5},
]
).repartition(num_parts)

# Default: nulls are included in the list
result = ds.groupby("group").aggregate(AsList(on="value")).take_all()
result_sorted = sorted(result, key=lambda x: x["group"])

# Sort the lists for comparison (None values will be at the end in sorted order)
for r in result_sorted:
# Separate None and non-None values for sorting
non_nulls = sorted([v for v in r["list(value)"] if v is not None])
nulls = [v for v in r["list(value)"] if v is None]
r["list(value)"] = non_nulls + nulls

assert result_sorted == [
{"group": "A", "list(value)": [1, 3, None]},
{"group": "B", "list(value)": [5, None]},
]

# With ignore_nulls=True: nulls are excluded from the list
result = (
ds.groupby("group").aggregate(AsList(on="value", ignore_nulls=True)).take_all()
)
result_sorted = sorted(result, key=lambda x: x["group"])

for r in result_sorted:
r["list(value)"] = sorted(r["list(value)"])

assert result_sorted == [
{"group": "A", "list(value)": [1, 3]},
{"group": "B", "list(value)": [5]},
]


@pytest.mark.parametrize("num_parts", [1, 30])
@pytest.mark.parametrize("ds_format", ["pandas", "pyarrow"])
def test_groupby_arrow_multi_agg(
6 changes: 5 additions & 1 deletion python/ray/data/tests/test_util.py
@@ -416,7 +416,11 @@ def test_merge_resources_to_ray_remote_args():
],
)
def test_rows_same(actual: pd.DataFrame, expected: pd.DataFrame, expected_equal: bool):
assert rows_same(actual, expected) == expected_equal
if expected_equal:
assert rows_same(actual, expected)
else:
with pytest.raises(AssertionError):
assert rows_same(actual, expected)


if __name__ == "__main__":
22 changes: 22 additions & 0 deletions python/ray/data/tests/unit/expressions/test_arithmetic.py
@@ -204,6 +204,28 @@ def test_reverse_floor_division(self, sample_data):
result.reset_index(drop=True), expected, check_names=False
)

# ── Modulo ──

@pytest.mark.parametrize(
"expr,expected_values",
[
(col("a") % 3, [1, 2, 0, 1]),
(col("a") % col("c"), [1.0, 0.0, 2.0, 4.0]),
(10 % col("b"), [0, 2, 0, 2]),
],
ids=["col_mod_int", "col_mod_fp", "col_rmod_int"],
)
def test_modulo(self, sample_data, expr, expected_values):
"""Test modulo operations."""
assert isinstance(expr, BinaryExpr)
assert expr.op == Operation.MOD
result = eval_expr(expr, sample_data)
pd.testing.assert_series_equal(
result.reset_index(drop=True),
pd.Series(expected_values, name=None),
check_names=False,
)


# ──────────────────────────────────────
# Complex Arithmetic Expressions