feat(expr-ir): Implement Acero `order_by`, `hashjoin` for `over` + `DataFrame.filter` #3173
base: oh-nodes
Conversation
```python
# NOTE: See (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021)
# The issue is `T` possibly being `Iterable`
# Ignoring here still leaks the issue to the caller, where you need to annotate the base case
@overload
def flatten_hash_safe(iterable: Iterable[OneOrIterable[str]], /) -> Iterator[str]: ...
```
It's an improvement over the previous version, but far from ideal.
It still doesn't resolve this case, and I'm not entirely sure why yet:
narwhals/narwhals/_plan/compliant/column.py, lines 49 to 60 in f77bb4c:
```python
@classmethod
def align(
    cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]]
) -> Iterator[SeriesT]:
    exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs))
    length = cls._length_required(exprs)
    if length is None:
        for e in exprs:
            yield e.to_series()
    else:
        for e in exprs:
            yield e.broadcast(length)
```
```python
def sort_by(
    by: OneOrIterable[str],
    *more_by: str,
    descending: OneOrIterable[bool] = False,
    nulls_last: bool = False,
) -> Decl:
    return SortMultipleOptions.parse(
        descending=descending, nulls_last=nulls_last
    ).to_arrow_acero(tuple(flatten_hash_safe((by, more_by))))
```
As of feat(expr-ir): Impl `acero.sort_by`, I still need to make use of this in a plan.
A good candidate might be in either/both of:

`over(order_by=...)`:
narwhals/narwhals/_plan/arrow/expr.py, lines 328 to 350 in f77bb4c:
```python
def over_ordered(
    self, node: ir.OrderedWindowExpr, frame: Frame, name: str
) -> Self | Scalar:
    if node.partition_by:
        msg = f"Need to implement `group_by`, `join` for:\n{node!r}"
        raise NotImplementedError(msg)
    # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
    sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
    options = node.sort_options.to_multiple(len(node.order_by))
    idx_name = temp.column_name(frame)
    sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
    evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
    if isinstance(evaluated, ArrowScalar):
        # NOTE: We're already sorted, defer broadcasting to the outer context
        # Wouldn't be suitable for partitions, but will be fine here
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70
        return self._with_native(evaluated.native, name)
    indices = pc.sort_indices(sorted_context.get_column(idx_name).native)
    height = len(sorted_context)
    result = evaluated.broadcast(height).native.take(indices)
    return self._with_native(result, name)
```
`is_{first,last}_distinct`:
narwhals/narwhals/_arrow/series.py, lines 719 to 747 in 715be22:
```python
def is_first_distinct(self) -> Self:
    import numpy as np  # ignore-banned-import

    row_number = pa.array(np.arange(len(self)))
    col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    first_distinct_index = (
        pa.Table.from_arrays([self.native], names=[self.name])
        .append_column(col_token, row_number)
        .group_by(self.name)
        .aggregate([(col_token, "min")])
        .column(f"{col_token}_min")
    )
    return self._with_native(pc.is_in(row_number, first_distinct_index))

def is_last_distinct(self) -> Self:
    import numpy as np  # ignore-banned-import

    row_number = pa.array(np.arange(len(self)))
    col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    last_distinct_index = (
        pa.Table.from_arrays([self.native], names=[self.name])
        .append_column(col_token, row_number)
        .group_by(self.name)
        .aggregate([(col_token, "max")])
        .column(f"{col_token}_max")
    )
    return self._with_native(pc.is_in(row_number, last_distinct_index))
```
Mostly following what is on `main` (so far).
Both are available at all levels, plus `to_series` is implemented in terms of `get_columns`.
`is_{first,last}_distinct` are among a few that fit that case.
`order_by`/`sort_by` pair
narwhals/_plan/arrow/acero.py (outdated):
```python
def join(
    left: pa.Table,
    right: pa.Table,
    how: JoinTypeSubset,
    left_on: OneOrIterable[str],
    right_on: OneOrIterable[str],
    suffix: str = "_right",
    *,
    coalesce_keys: bool = True,
) -> Decl:
    """Heavily based on [`pyarrow.acero._perform_join`].

    [`pyarrow.acero._perform_join`]: https://github.com/apache/arrow/blob/f7320c9a40082639f9e0cf8b3075286e3fc6c0b9/python/pyarrow/acero.py#L82-L260
    """
```
TODO: Investigate using non-`table_source` nodes
- AFAICT, `hashjoin` should be able to accept things like `project`
- Defining the handling for `cross` joins using `acero` directly seems very achievable
narwhals/narwhals/_arrow/dataframe.py, lines 400 to 422 in f4787d3:
if how == "cross": plx = self.__narwhals_namespace__() key_token = generate_temporary_column_name( n_bytes=8, columns=[*self.columns, *other.columns] ) return self._with_native( self.with_columns( plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) ) .native.join( other.with_columns( plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) ).native, keys=key_token, right_keys=key_token, join_type="inner", right_suffix=suffix, ) .drop([key_token]) ) - There does need to be a new layer for tracking
Schema
changes- Which is needed for
with_columns
also - Generally, the responsibility for a future
LogicalPlan
- Which is needed for
- Starting to build up the join test suite
  - At some point, `"cross"` support will be needed

Everything else requires another feature to be implemented:
- `DataFrame.filter` for semi, anti
- `DataFrame.collect_schema` for suffix
- `how="cross"` is just being deferred currently (#3173 (comment))
50 lines! Even after all this refactoring 😔
tests/plan/join_test.py (outdated):
```python
# NOTE: Maybe merge `semi`, `anti` into the same test which just inverts the predicate?
@XFAIL_DATAFRAME_FILTER
@pytest.mark.parametrize(
    ("on", "predicate", "expected"),
    [
        ("a", (nwp.col("b") > 5), {"a": [2], "b": [6], "zor ro": [9]}),
        (["b"], (nwp.col("b") < 5), {"a": [1, 3], "b": [4, 4], "zor ro": [7, 8]}),
        (["a", "b"], (nwp.col("b") < 5), {"a": [1, 3], "b": [4, 4], "zor ro": [7, 8]}),
    ],
)
def test_join_semi(
    on: On, predicate: nwp.Expr, expected: Data
) -> None:  # pragma: no cover
    data = {"a": [1, 3, 2], "b": [4, 4, 6], "zor ro": [7.0, 8.0, 9.0]}
    df = dataframe(data)
    other = df.filter(predicate)  # type: ignore[attr-defined]
```
It's interesting I got this far without `DataFrame.filter` 😂
I do have an `acero` version, which operates over `pc.Expression`, though:
narwhals/narwhals/_plan/arrow/acero.py, lines 172 to 174 in b0c2a4d:
```python
def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl:
    expr = _parse_all_horizontal(predicates, constraints)
    return Decl("filter", options=pac.FilterNodeOptions(expr))
```
So the missing link between those two is approximately this, but with a fallback to an eager path like `main`:
narwhals/narwhals/_arrow/dataframe.py, lines 521 to 529 in 8ac061c:
```python
def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
    if isinstance(predicate, list):
        mask_native: Mask | ChunkedArrayAny = predicate
    else:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask_native = self._evaluate_into_exprs(predicate)[0].native
    return self._with_native(
        self.native.filter(mask_native), validate_column_names=False
    )
```
- Ideally these would be `str | Selector`, or `Expr` containing only selections
  - But that isn't possible with the current typing
- They *can* accept more
  - But it increases the complexity quite a lot for eager
Need similar logic for `DataFrame.filter`
Pretty sure on `main` that ignoring constraints is a bug
Quite handy that I did `Expr.filter` and `When` first 😄
narwhals/_plan/dataframe.py (outdated):
```python
if len(predicates) == 1 and not constraints:
    first = predicates[0]
    if is_list_of(first, bool):
        series = self._series.from_iterable(
            first,
            dtype=self.version.dtypes.Boolean(),
            backend=self.implementation,
        )
    elif is_series(first):
        series = first
    else:
        return super().filter(first)
    return self._with_compliant(self._compliant.filter(series._compliant))
non_mask = cast("tuple[OneOrIterable[IntoExprColumn],...]", predicates)
return super().filter(*non_mask, **constraints)
```
If anyone sees this: don't copy this logic to solve #3182, I haven't handled it here either.

Update: Should be clearer now with these failing tests (test: Add `test_filter_mask_mixed`); the exact text is allowed to change.

Some basic cases to consider for #3182. If we decide against supporting them, then all can be converted into a `pytest.raises`.
Really don't want this being part of the `ArrowDataFrame` constructor. Viewing `join` as an edge case, whereas things like `select`, `with_columns` already handle duplicates during `prepare_projections`.
```python
def filter(self, predicate: NamedIR) -> Self:
    mask: pc.Expression | ChunkedArrayAny
    if not fn.is_series(predicate):
        resolved = Expr.from_named_ir(predicate, self)
        if isinstance(resolved, Expr):
            mask = resolved.broadcast(len(self)).native
        else:
            mask = acero.lit(resolved.native)
    else:
        mask = predicate.native
    return self._with_native(self.native.filter(mask))
```
This is very nice now 😄
Tracking:
- `DataFrame.filter` silently ignores `**constraints` when using `list[bool]` #3182

Related issues
- `Expr` IR #2572
- `group_by`, utilize `pyarrow.acero` #3143

Description
Note

I've used the name `sort_by` for our wrapper of `order_by`. The node corresponds to `pa.Table.sort_by`, whereas the name `order_by` doesn't appear anywhere else in `pyarrow`.
Building out more `acero` parts to be able to support `.over(*partition_by)`.