-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[Data] Added AsList aggregation
#59920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a423005
6b3ed26
f22ef48
8696815
fbbcd1d
637878d
db54438
8757041
c930012
f9d3005
a2956fa
071a879
72c3fa4
873e077
d9345cd
5cdaa8a
5cc7bcf
06d5caa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -174,6 +174,28 @@ class AggregateFnV2(AggregateFn, abc.ABC, Generic[AccumulatorType, AggOutputType | |||||
| 4. **Finalization**: Optionally, the `finalize` method transforms the | ||||||
| final combined accumulator into the desired output format. | ||||||
|
|
||||||
| Generic Type Parameters: | ||||||
| This class is parameterized by two type variables: | ||||||
|
|
||||||
| - ``AccumulatorType``: The type of the intermediate state (accumulator) used | ||||||
| during aggregation. This is what `aggregate_block` returns, what `combine` | ||||||
| takes as inputs and returns, and what `finalize` receives. For simple | ||||||
| aggregations like `Sum`, this might just be a numeric type. For more complex | ||||||
| aggregations like `Mean`, this could be a composite type like | ||||||
| ``List[Union[int, float]]`` representing ``[sum, count]``. | ||||||
|
|
||||||
| - ``AggOutputType``: The type of the final result after `finalize` is called. | ||||||
| This is what gets written to the output dataset. For `Sum`, this is the | ||||||
| same as the accumulator type (a number). For `Mean`, the accumulator is | ||||||
| ``[sum, count]`` but the output is a single ``float`` (the computed mean). | ||||||
|
|
||||||
| Examples of type parameterization in built-in aggregations:: | ||||||
|
|
||||||
| Count(AggregateFnV2[int, int]) # accumulator: int, output: int | ||||||
| Sum(AggregateFnV2[Union[int, float], ...]) # accumulator: number, output: number | ||||||
| Mean(AggregateFnV2[List[...], float]) # accumulator: [sum, count], output: float | ||||||
| Std(AggregateFnV2[List[...], float]) # accumulator: [M2, mean, count], output: float | ||||||
|
|
||||||
| Args: | ||||||
| name: The name of the aggregation. This will be used as the column name | ||||||
| in the output, e.g., "sum(my_col)". | ||||||
|
|
@@ -375,6 +397,69 @@ def combine(self, current_accumulator: int, new: int) -> int: | |||||
| return current_accumulator + new | ||||||
|
|
||||||
|
|
||||||
| @PublicAPI | ||||||
| class AsList(AggregateFnV2[List, List]): | ||||||
| """Listing aggregation combining all values within the group into a single | ||||||
| list element. | ||||||
|
|
||||||
| Example: | ||||||
|
|
||||||
| .. testcode:: | ||||||
| :skipif: True | ||||||
|
|
||||||
| # Skip testing b/c this example require proper ordering of the output | ||||||
| # to be robust and not flaky | ||||||
|
|
||||||
| import ray | ||||||
| from ray.data.aggregate import AsList | ||||||
|
|
||||||
| ds = ray.data.range(10) | ||||||
| # Schema: {'id': int64} | ||||||
| ds = ds.add_column("group_key", lambda x: x % 3) | ||||||
| # Schema: {'id': int64, 'group_key': int64} | ||||||
|
|
||||||
| # Listing all elements per group: | ||||||
| result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all() | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add sort for testing?
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's not sufficient -- we need to order the list too which makes this code example really clumsy (hence why i'm skipping testing it
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enabliling
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Preserve order won't guarantee sequence order either b/c it depends on the order of arrival of the shards into HashShuffleAggregator |
||||||
| # result: [{'group_key': 0, 'list(id)': [0, 3, 6, 9]}, | ||||||
| # {'group_key': 1, 'list(id)': [1, 4, 7]}, | ||||||
| # {'group_key': 2, 'list(id)': [2, 5, 8]} | ||||||
|
|
||||||
| Args: | ||||||
| on: The name of the column to collect values from. Must be provided. | ||||||
| alias_name: Optional name for the resulting column. | ||||||
| ignore_nulls: Whether to ignore null values when collecting. If `True`, | ||||||
| nulls are skipped. If `False` (default), nulls are included in the list. | ||||||
| """ | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| on: str, | ||||||
| alias_name: Optional[str] = None, | ||||||
| ignore_nulls: bool = False, | ||||||
| ): | ||||||
| super().__init__( | ||||||
| alias_name if alias_name else f"list({on or ''})", | ||||||
| on=on, | ||||||
| ignore_nulls=ignore_nulls, | ||||||
| zero_factory=lambda: [], | ||||||
| ) | ||||||
|
|
||||||
| def aggregate_block(self, block: Block) -> AccumulatorType: | ||||||
| column_accessor = BlockColumnAccessor.for_column( | ||||||
| block[self.get_target_column()] | ||||||
| ) | ||||||
|
|
||||||
| if self._ignore_nulls: | ||||||
| column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna()) | ||||||
|
|
||||||
| return column_accessor.to_pylist() | ||||||
|
|
||||||
| def combine( | ||||||
| self, current_accumulator: AccumulatorType, new: AccumulatorType | ||||||
| ) -> AccumulatorType: | ||||||
| return current_accumulator + new | ||||||
|
|
||||||
|
|
||||||
| @PublicAPI | ||||||
| class Sum(AggregateFnV2[Union[int, float], Union[int, float]]): | ||||||
| """Defines sum aggregation. | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.