diff --git a/doc/source/data/api/aggregate.rst b/doc/source/data/api/aggregate.rst index e1f2795baaed..7955fabd965f 100644 --- a/doc/source/data/api/aggregate.rst +++ b/doc/source/data/api/aggregate.rst @@ -25,6 +25,7 @@ compute aggregations. AbsMax Quantile Unique + AsList CountDistinct ValueCounter MissingValuePercentage diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 7f5137d80485..f15b593da53b 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -44,6 +44,7 @@ def _pa_is_in(left: Any, right: Any) -> Any: Operation.SUB: operator.sub, Operation.MUL: operator.mul, Operation.DIV: operator.truediv, + Operation.MOD: operator.mod, Operation.FLOORDIV: operator.floordiv, Operation.GT: operator.gt, Operation.LT: operator.lt, @@ -128,6 +129,11 @@ def _pa_add_or_concat(left: Any, right: Any) -> Any: Operation.SUB: pc.subtract, Operation.MUL: pc.multiply, Operation.DIV: pc.divide, + Operation.MOD: lambda left, right: ( + # Modulo op is essentially: + # r = N - floor(N/M) * M + pc.subtract(left, pc.multiply(pc.floor(pc.divide(left, right)), right)) + ), Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)), Operation.GT: pc.greater, Operation.LT: pc.less, diff --git a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py index 3d2af8c03fcf..a681a902d3be 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py @@ -24,6 +24,7 @@ Operation.SUB: "-", Operation.MUL: "*", Operation.DIV: "/", + Operation.MOD: "%", Operation.FLOORDIV: "//", Operation.GT: ">", Operation.LT: "<", diff --git a/python/ray/data/_internal/util.py 
b/python/ray/data/_internal/util.py index 290fac5f7ebd..aca398f8b980 100644 --- a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -1716,15 +1716,12 @@ def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool: if len(actual) == 0: return True - try: - pd.testing.assert_frame_equal( - _sort_df(actual).reset_index(drop=True), - _sort_df(expected).reset_index(drop=True), - check_dtype=False, - ) - return True - except AssertionError: - return False + pd.testing.assert_frame_equal( + _sort_df(actual).reset_index(drop=True), + _sort_df(expected).reset_index(drop=True), + check_dtype=False, + ) + return True def merge_resources_to_ray_remote_args( diff --git a/python/ray/data/aggregate.py b/python/ray/data/aggregate.py index 3ea70ff940fc..c616d396704a 100644 --- a/python/ray/data/aggregate.py +++ b/python/ray/data/aggregate.py @@ -174,6 +174,28 @@ class AggregateFnV2(AggregateFn, abc.ABC, Generic[AccumulatorType, AggOutputType 4. **Finalization**: Optionally, the `finalize` method transforms the final combined accumulator into the desired output format. + Generic Type Parameters: + This class is parameterized by two type variables: + + - ``AccumulatorType``: The type of the intermediate state (accumulator) used + during aggregation. This is what `aggregate_block` returns, what `combine` + takes as inputs and returns, and what `finalize` receives. For simple + aggregations like `Sum`, this might just be a numeric type. For more complex + aggregations like `Mean`, this could be a composite type like + ``List[Union[int, float]]`` representing ``[sum, count]``. + + - ``AggOutputType``: The type of the final result after `finalize` is called. + This is what gets written to the output dataset. For `Sum`, this is the + same as the accumulator type (a number). For `Mean`, the accumulator is + ``[sum, count]`` but the output is a single ``float`` (the computed mean). 
+
+    Examples of type parameterization in built-in aggregations::
+
+        Count(AggregateFnV2[int, int])  # accumulator: int, output: int
+        Sum(AggregateFnV2[Union[int, float], ...])  # accumulator: number, output: number
+        Mean(AggregateFnV2[List[...], float])  # accumulator: [sum, count], output: float
+        Std(AggregateFnV2[List[...], float])  # accumulator: [M2, mean, count], output: float
+
     Args:
         name: The name of the aggregation. This will be used as the column name
             in the output, e.g., "sum(my_col)".
@@ -375,6 +397,69 @@ def combine(self, current_accumulator: int, new: int) -> int:
         return current_accumulator + new
 
 
+@PublicAPI
+class AsList(AggregateFnV2[List, List]):
+    """Listing aggregation combining all values within the group into a single
+    list element.
+
+    Example:
+
+    .. testcode::
+        :skipif: True
+
+        # Skip testing b/c this example requires proper ordering of the output
+        # to be robust and not flaky
+
+        import ray
+        from ray.data.aggregate import AsList
+
+        ds = ray.data.range(10)
+        # Schema: {'id': int64}
+        ds = ds.add_column("group_key", lambda x: x["id"] % 3)
+        # Schema: {'id': int64, 'group_key': int64}
+
+        # Listing all elements per group:
+        result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
+        # result: [{'group_key': 0, 'list(id)': [0, 3, 6, 9]},
+        #          {'group_key': 1, 'list(id)': [1, 4, 7]},
+        #          {'group_key': 2, 'list(id)': [2, 5, 8]}]
+
+    Args:
+        on: The name of the column to collect values from. Must be provided.
+        alias_name: Optional name for the resulting column.
+        ignore_nulls: Whether to ignore null values when collecting. If `True`,
+            nulls are skipped. If `False` (default), nulls are included in the list.
+    """
+
+    def __init__(
+        self,
+        on: str,
+        alias_name: Optional[str] = None,
+        ignore_nulls: bool = False,
+    ):
+        super().__init__(
+            alias_name if alias_name else f"list({on or ''})",
+            on=on,
+            ignore_nulls=ignore_nulls,
+            zero_factory=lambda: [],
+        )
+
+    def aggregate_block(self, block: Block) -> List:
+        """Collect the target column's values from one block into a list."""
+        column_accessor = BlockColumnAccessor.for_column(
+            block[self.get_target_column()]
+        )
+
+        if self._ignore_nulls:
+            column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna())
+
+        return column_accessor.to_pylist()
+
+    def combine(self, current_accumulator: List, new: List) -> List:
+        """Concatenate two partial lists into a new list."""
+        return current_accumulator + new
+
+
 @PublicAPI
 class Sum(AggregateFnV2[Union[int, float], Union[int, float]]):
     """Defines sum aggregation.
diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py
index 7983b70120c3..adfb6332d4a8 100644
--- a/python/ray/data/expressions.py
+++ b/python/ray/data/expressions.py
@@ -69,6 +69,7 @@ class Operation(Enum):
     SUB = "sub"
     MUL = "mul"
     DIV = "div"
+    MOD = "mod"
     FLOORDIV = "floordiv"
     GT = "gt"
     LT = "lt"
@@ -299,7 +300,10 @@ def _bin(self, other: Any, op: Operation) -> "Expr":
             other = LiteralExpr(other)
         return BinaryExpr(op, self, other)
 
-    # arithmetic
+    #
+    # Arithmetic ops
+    #
+
     def __add__(self, other: Any) -> "Expr":
         """Addition operator (+)."""
         return self._bin(other, Operation.ADD)
@@ -324,6 +328,14 @@ def __rmul__(self, other: Any) -> "Expr":
         """Reverse multiplication operator (for literal * expr)."""
         return LiteralExpr(other)._bin(self, Operation.MUL)
 
+    def __mod__(self, other: Any) -> "Expr":
+        """Modulo operator (%)."""
+        return self._bin(other, Operation.MOD)
+
+    def __rmod__(self, other: Any) -> "Expr":
+        """Reverse modulo operator (for literal % expr)."""
+        return LiteralExpr(other)._bin(self, Operation.MOD)
+
     def __truediv__(self, other: Any) -> "Expr":
         """Division operator (/)."""
         return self._bin(other, Operation.DIV)
diff --git a/python/ray/data/tests/test_groupby_e2e.py
b/python/ray/data/tests/test_groupby_e2e.py
index 098a2efc8421..421495e334bd 100644
--- a/python/ray/data/tests/test_groupby_e2e.py
+++ b/python/ray/data/tests/test_groupby_e2e.py
@@ -20,6 +20,7 @@
 from ray.data.aggregate import (
     AbsMax,
     AggregateFn,
+    AsList,
     Count,
     CountDistinct,
     Max,
@@ -32,6 +33,7 @@
 )
 from ray.data.block import BlockAccessor
 from ray.data.context import DataContext, ShuffleStrategy
+from ray.data.expressions import col
 from ray.data.tests.conftest import *  # noqa
 from ray.data.tests.util import named_values
 from ray.tests.conftest import *  # noqa
@@ -468,6 +470,79 @@ def _to_batch_format(ds):
     )
 
 
+@pytest.mark.parametrize("num_parts", [1, 10])
+def test_as_list_e2e(
+    ray_start_regular_shared_2_cpus,
+    num_parts,
+    disable_fallback_to_object_extension,
+):
+    """AsList gathers every value of a group into a single list."""
+    ds = ray.data.range(10)
+    ds = ds.with_column("group_key", col("id") % 3).repartition(num_parts)
+
+    # Listing all elements per group:
+    result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
+
+    # Element order within a list isn't asserted, so normalize before comparing
+    for row in result:
+        row["list(id)"] = sorted(row["list(id)"])
+
+    assert sorted(result, key=lambda x: x["group_key"]) == [
+        {"group_key": 0, "list(id)": [0, 3, 6, 9]},
+        {"group_key": 1, "list(id)": [1, 4, 7]},
+        {"group_key": 2, "list(id)": [2, 5, 8]},
+    ]
+
+
+@pytest.mark.parametrize("num_parts", [1, 10])
+def test_as_list_with_nulls(
+    ray_start_regular_shared_2_cpus,
+    num_parts,
+    disable_fallback_to_object_extension,
+):
+    """AsList keeps nulls by default and drops them with ignore_nulls=True."""
+    # Test with nulls included (default behavior: ignore_nulls=False)
+    ds = ray.data.from_items(
+        [
+            {"group": "A", "value": 1},
+            {"group": "A", "value": None},
+            {"group": "A", "value": 3},
+            {"group": "B", "value": None},
+            {"group": "B", "value": 5},
+        ]
+    ).repartition(num_parts)
+
+    # Default: nulls are included in the list
+    result = ds.groupby("group").aggregate(AsList(on="value")).take_all()
+    result_sorted = sorted(result, key=lambda x: x["group"])
+
+    # Normalize each list for comparison: sorted non-null values first, then the
+    # nulls (None can't be ordered against ints, so plain sorted() would fail)
+    for r in result_sorted:
+        non_nulls = sorted(v for v in r["list(value)"] if v is not None)
+        nulls = [v for v in r["list(value)"] if v is None]
+        r["list(value)"] = non_nulls + nulls
+
+    assert result_sorted == [
+        {"group": "A", "list(value)": [1, 3, None]},
+        {"group": "B", "list(value)": [5, None]},
+    ]
+
+    # With ignore_nulls=True: nulls are excluded from the list
+    result = (
+        ds.groupby("group").aggregate(AsList(on="value", ignore_nulls=True)).take_all()
+    )
+    result_sorted = sorted(result, key=lambda x: x["group"])
+
+    for r in result_sorted:
+        r["list(value)"] = sorted(r["list(value)"])
+
+    assert result_sorted == [
+        {"group": "A", "list(value)": [1, 3]},
+        {"group": "B", "list(value)": [5]},
+    ]
+
+
 @pytest.mark.parametrize("num_parts", [1, 30])
 @pytest.mark.parametrize("ds_format", ["pandas", "pyarrow"])
 def test_groupby_arrow_multi_agg(
diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py
index 31347e7d96fa..b0a62c71671e 100644
--- a/python/ray/data/tests/test_util.py
+++ b/python/ray/data/tests/test_util.py
@@ -416,7 +416,11 @@ def test_merge_resources_to_ray_remote_args():
     ],
 )
 def test_rows_same(actual: pd.DataFrame, expected: pd.DataFrame, expected_equal: bool):
-    assert rows_same(actual, expected) == expected_equal
+    if expected_equal:
+        assert rows_same(actual, expected)
+    else:
+        with pytest.raises(AssertionError):
+            assert rows_same(actual, expected)
 
 
 if __name__ == "__main__":
diff --git a/python/ray/data/tests/unit/expressions/test_arithmetic.py b/python/ray/data/tests/unit/expressions/test_arithmetic.py
index 15579954344d..a0ac757d7490 100644
--- a/python/ray/data/tests/unit/expressions/test_arithmetic.py
+++
b/python/ray/data/tests/unit/expressions/test_arithmetic.py
@@ -204,6 +204,24 @@ def test_reverse_floor_division(self, sample_data):
             result.reset_index(drop=True), expected, check_names=False
         )
 
+    # ── Modulo ──
+
+    @pytest.mark.parametrize(
+        "expr,expected_values",
+        [
+            pytest.param(col("a") % 3, [1, 2, 0, 1], id="col_mod_int"),
+            pytest.param(col("a") % col("c"), [1.0, 0.0, 2.0, 4.0], id="col_mod_fp"),
+            pytest.param(10 % col("b"), [0, 2, 0, 2], id="col_rmod_int"),
+        ],
+    )
+    def test_modulo(self, sample_data, expr, expected_values):
+        """Check that `%` builds a MOD BinaryExpr and evaluates as expected."""
+        assert isinstance(expr, BinaryExpr)
+        assert expr.op == Operation.MOD
+        actual = eval_expr(expr, sample_data).reset_index(drop=True)
+        expected = pd.Series(expected_values, name=None)
+        pd.testing.assert_series_equal(actual, expected, check_names=False)
+
     # ──────────────────────────────────────
     # Complex Arithmetic Expressions