Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
cb470b4
refactor: Use `temp.column_name(s)` some more
dangotbanned Oct 1, 2025
23e9d43
fix(typing): Resolve some cases for `flatten_hash_safe`
dangotbanned Oct 1, 2025
f77bb4c
feat(expr-ir): Impl `acero.sort_by`
dangotbanned Oct 2, 2025
36ddce0
test: Port over `is_first_distinct` tests
dangotbanned Oct 2, 2025
0e49f57
chore: Add `Compliant{Expr,Scalar}.is_{first,last}_distinct`
dangotbanned Oct 2, 2025
a5f192c
test: Update to cover `is_last_distinct` as well
dangotbanned Oct 2, 2025
6a1b08a
feat(DRAFT): Initial `is_first_distinct` impl
dangotbanned Oct 2, 2025
1c026bf
test: Port over more cases
dangotbanned Oct 3, 2025
e7e8a04
refactor: Generalize `is_first_distinct` impl
dangotbanned Oct 3, 2025
2d46521
feat: Add `is_last_distinct`
dangotbanned Oct 3, 2025
cfb775d
refactor: Make both `is_*_distinct` methods, aliases
dangotbanned Oct 3, 2025
9db603b
feat: (Properly) add `get_column`, `to_series`
dangotbanned Oct 3, 2025
f8255d3
chore: Add `pc.is_in` wrapper
dangotbanned Oct 3, 2025
6fe2a0a
docs: Add detail to `FunctionFlags.LENGTH_PRESERVING`
dangotbanned Oct 3, 2025
938befb
test: More test porting
dangotbanned Oct 3, 2025
516f4a6
typo
dangotbanned Oct 3, 2025
ead4e62
feat(DRAFT): Some progress on `hashjoin` port
dangotbanned Oct 4, 2025
273bdcc
fix: Correctly pass down join keys
dangotbanned Oct 5, 2025
ce37617
test: Port over inner, left & clean up
dangotbanned Oct 5, 2025
18ef26a
test: Add `test_suffix`
dangotbanned Oct 5, 2025
94baf1e
test: Add `how="cross"` tests
dangotbanned Oct 5, 2025
733b45a
test: Add `how={"anti","semi"}` tests
dangotbanned Oct 5, 2025
ce321e0
test: replace `"antananarivo"`->`"a"`, `"bob"`->`"b"`
dangotbanned Oct 5, 2025
cc0d379
test: Port the other duplicate test
dangotbanned Oct 5, 2025
dd40e3a
test: Make all the xfails more visible
dangotbanned Oct 5, 2025
d1a1785
feat(DRAFT): Initial acero cross-join impl
dangotbanned Oct 5, 2025
77e55b3
refactor: Only expose `acero.join_tables`
dangotbanned Oct 5, 2025
8f7d2f3
chore: Start factoring-out `Table` dependency
dangotbanned Oct 5, 2025
b0c2a4d
Merge branch 'oh-nodes' into expr-ir/acero-order-by
dangotbanned Oct 6, 2025
d42f5de
refactor(typing): Use `IntoExprColumn` some more
dangotbanned Oct 6, 2025
b8a58c1
refactor: Split up `_parse_sort_by`
dangotbanned Oct 6, 2025
05c63fd
Make a start on `DataFrame.filter`
dangotbanned Oct 6, 2025
025213d
fill out slightly more `filter`
dangotbanned Oct 6, 2025
3e94449
get typing working again (kinda)
dangotbanned Oct 6, 2025
a611bc9
feat(DRAFT): Support `filter(list[bool])`
dangotbanned Oct 6, 2025
d514ad0
feat: Support single `Series` as well
dangotbanned Oct 6, 2025
d452920
test: Use `parametrize`
dangotbanned Oct 6, 2025
4c7c23d
feat: Add predicate expansion
dangotbanned Oct 6, 2025
2ebca30
feat(expr-ir): Full `DataFrame.filter` support
dangotbanned Oct 6, 2025
1b66786
test: Merge the anti/semi tests
dangotbanned Oct 6, 2025
fd38911
test: parametrize exception messages
dangotbanned Oct 6, 2025
3537cac
test: relax more error messages
dangotbanned Oct 6, 2025
b5ef86b
typo
dangotbanned Oct 7, 2025
8433b2d
test: Add `test_filter_mask_mixed`
dangotbanned Oct 7, 2025
7668abb
fix: Raise on duplicate column names
dangotbanned Oct 7, 2025
3ca43d1
cov
dangotbanned Oct 7, 2025
0f06479
perf: Avoid multiple collections during cross join
dangotbanned Oct 7, 2025
7e9ee74
test: Stop repeating the same data so many times
dangotbanned Oct 7, 2025
1523dbb
test: Add some cases from polars
dangotbanned Oct 8, 2025
a479f32
fix: typing mypy
dangotbanned Oct 8, 2025
8e840e0
feat(expr-ir): Full-er `DataFrame.filter` support
dangotbanned Oct 8, 2025
af26916
refactor: Simplify the `NonCrossJoinStrategy` split
dangotbanned Oct 8, 2025
6aaf75d
test: Convert raising test into a conformance test
dangotbanned Oct 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions narwhals/_plan/_expr_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,7 @@ def is_column(self, *, allow_aliasing: bool = False) -> bool:

ir = self.expr
return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing)


def named_ir(name: str, expr: ExprIRT, /) -> NamedIR[ExprIRT]:
    """Convenience constructor pairing ``expr`` with its output ``name``."""
    return NamedIR(name=name, expr=expr)
114 changes: 107 additions & 7 deletions narwhals/_plan/arrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
import pyarrow.compute as pc # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals._plan.common import flatten_hash_safe
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import OneOrSeq
from narwhals.typing import SingleColSelector
from narwhals.typing import JoinStrategy, SingleColSelector

if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterable, Iterator
from collections.abc import Callable, Collection, Iterable, Iterator, Mapping

from typing_extensions import TypeAlias

Expand All @@ -38,7 +40,7 @@
Aggregation as _Aggregation,
)
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.arrow.typing import NullPlacement
from narwhals._plan.arrow.typing import JoinTypeSubset, NullPlacement
from narwhals._plan.typing import OneOrIterable, Order, Seq
from narwhals.typing import NonNestedLiteral

Expand All @@ -64,6 +66,14 @@
lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar)
"""Alias for `pyarrow.compute.scalar`."""

_HOW_JOIN: Mapping[JoinStrategy, JoinTypeSubset] = {
"inner": "inner",
"left": "left outer",
"full": "full outer",
"anti": "left anti",
"semi": "left semi",
}


# NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg
# Due to expr expansion, it is very likely that we have repeat runs
Expand Down Expand Up @@ -189,10 +199,81 @@ def _order_by(
return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement))


# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero`
def sort_by(*args: Any, **kwds: Any) -> Decl:
msg = "Should convert from polars args -> use `_order_by"
raise NotImplementedError(msg)
def sort_by(
    by: OneOrIterable[str],
    *more_by: str,
    descending: OneOrIterable[bool] = False,
    nulls_last: bool = False,
) -> Decl:
    """Build an acero ``order_by`` declaration from polars-style sort arguments."""
    keys = tuple(flatten_hash_safe((by, more_by)))
    options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last)
    return options.to_arrow_acero(keys)
Comment on lines +261 to +269
Copy link
Member Author

@dangotbanned dangotbanned Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of feat(expr-ir): Impl acero.sort_by, I still need to make use of this in a plan.

A good candidate might be in either/both of

over(order_by=...)

def over_ordered(
self, node: ir.OrderedWindowExpr, frame: Frame, name: str
) -> Self | Scalar:
if node.partition_by:
msg = f"Need to implement `group_by`, `join` for:\n{node!r}"
raise NotImplementedError(msg)
# NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
options = node.sort_options.to_multiple(len(node.order_by))
idx_name = temp.column_name(frame)
sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
if isinstance(evaluated, ArrowScalar):
# NOTE: We're already sorted, defer broadcasting to the outer context
# Wouldn't be suitable for partitions, but will be fine here
# - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667
# - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70
return self._with_native(evaluated.native, name)
indices = pc.sort_indices(sorted_context.get_column(idx_name).native)
height = len(sorted_context)
result = evaluated.broadcast(height).native.take(indices)
return self._with_native(result, name)

is_{first,last}_distinct

def is_first_distinct(self) -> Self:
import numpy as np # ignore-banned-import
row_number = pa.array(np.arange(len(self)))
col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
first_distinct_index = (
pa.Table.from_arrays([self.native], names=[self.name])
.append_column(col_token, row_number)
.group_by(self.name)
.aggregate([(col_token, "min")])
.column(f"{col_token}_min")
)
return self._with_native(pc.is_in(row_number, first_distinct_index))
def is_last_distinct(self) -> Self:
import numpy as np # ignore-banned-import
row_number = pa.array(np.arange(len(self)))
col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
last_distinct_index = (
pa.Table.from_arrays([self.native], names=[self.name])
.append_column(col_token, row_number)
.group_by(self.name)
.aggregate([(col_token, "max")])
.column(f"{col_token}_max")
)
return self._with_native(pc.is_in(row_number, last_distinct_index))



def join(
    left: pa.Table,
    right: pa.Table,
    how: JoinTypeSubset,
    left_on: OneOrIterable[str],
    right_on: OneOrIterable[str],
    suffix: str = "_right",
    *,
    coalesce_keys: bool = True,
) -> Decl:
    """Heavily based on [`pyarrow.acero._perform_join`].

    [`pyarrow.acero._perform_join`]: https://github.com/apache/arrow/blob/f7320c9a40082639f9e0cf8b3075286e3fc6c0b9/python/pyarrow/acero.py#L82-L260
    """
    keys_left = [left_on] if isinstance(left_on, str) else list(left_on)
    keys_right = [right_on] if isinstance(right_on, str) else list(right_on)

    # polars full join does not coalesce keys, so never coalesce for "full outer"
    if not (coalesce_keys and how != "full outer"):
        opts = _join_options(how, keys_left, keys_right, suffix=suffix)
        return _hashjoin(left, right, opts)

    # By default expose all columns on both left and right table
    out_left = left.schema.names
    if how in {"left semi", "left anti"}:
        # Filtering joins never emit right-side columns.
        out_right: list[str] = []
    elif how in {"inner", "left outer"}:
        # Keys are coalesced into the left columns; drop the right-side copies.
        out_right = [name for name in right.schema.names if name not in keys_right]
    else:
        out_right = right.schema.names
    options = _join_options(
        how,
        keys_left,
        keys_right,
        suffix=suffix,
        left_output=out_left,
        right_output=out_right,
    )
    return _hashjoin(left, right, options)


def _join_options(
    how: JoinTypeSubset,
    left_on: str | list[str],
    right_on: str | list[str],
    *,
    suffix: str = "_right",
    left_output: Iterable[str] | None = None,
    right_output: Iterable[str] | None = None,
) -> pac.HashJoinNodeOptions:
    """Assemble ``HashJoinNodeOptions`` for a single ``hashjoin`` node."""
    # NOTE: The upstream constructor is untyped, hence `Incomplete` + the ignore.
    constructor: Incomplete = pac.HashJoinNodeOptions
    return constructor(  # type: ignore[no-any-return]
        how,
        left_on,
        right_on,
        left_output=left_output,
        right_output=right_output,
        output_suffix_for_right=suffix,
    )


def _hashjoin(
    left: pa.Table, right: pa.Table, /, options: pac.HashJoinNodeOptions
) -> Decl:
    """Wrap both tables as sources feeding a single ``hashjoin`` node."""
    sources = [table_source(left), table_source(right)]
    return Decl("hashjoin", options, sources)


def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table:
Expand Down Expand Up @@ -251,3 +332,22 @@ def select_names_table(
native: pa.Table, column_names: OneOrIterable[str], *more_names: str
) -> pa.Table:
return collect(table_source(native), select_names(column_names, *more_names))


def join_tables(
    left: pa.Table,
    right: pa.Table,
    how: JoinStrategy,
    left_on: OneOrIterable[str] | None,
    right_on: OneOrIterable[str] | None = (),
    suffix: str = "_right",
    *,
    coalesce_keys: bool = True,
) -> pa.Table:
    """Translate a narwhals join strategy into an acero plan and execute it eagerly."""
    # Raises KeyError for strategies absent from `_HOW_JOIN` (e.g. "cross").
    join_type = _HOW_JOIN[how]
    keys_left = left_on or ()
    # Missing right keys fall back to the left ones (same-named join columns).
    keys_right = right_on or keys_left
    declaration = join(
        left,
        right,
        join_type,
        keys_left,
        keys_right,
        suffix,
        coalesce_keys=coalesce_keys,
    )
    return collect(declaration)
21 changes: 19 additions & 2 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pyarrow.compute as pc # ignore-banned-import

from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow import acero, functions as fn
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.compliant.dataframe import EagerDataFrame
Expand All @@ -31,7 +31,7 @@
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import Seq
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema
from narwhals.typing import IntoSchema, JoinStrategy


class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]):
Expand Down Expand Up @@ -144,3 +144,20 @@ def select_names(self, *column_names: str) -> Self:
def row(self, index: int) -> tuple[Any, ...]:
row = self.native.slice(index, 1)
return tuple(chain.from_iterable(row.to_pydict().values()))

def join(
    self,
    other: Self,
    *,
    how: JoinStrategy,
    left_on: Sequence[str] | None,
    right_on: Sequence[str] | None,
    suffix: str = "_right",
) -> Self:
    """Join with ``other`` via an acero hashjoin plan.

    NOTE(review): "cross" presumably takes a dedicated key-less path — confirm.
    """
    if how == "cross":
        msg = f"join(how={how!r})"
        raise NotImplementedError(msg)
    joined = acero.join_tables(
        self.native, other.native, how, left_on, right_on, suffix=suffix
    )
    return self._with_native(joined)
45 changes: 32 additions & 13 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@
import pyarrow.compute as pc # ignore-banned-import

from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._plan import expressions as ir
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
from narwhals._plan.common import temp
from narwhals._plan.compliant.column import ExprDispatch
from narwhals._plan.compliant.expr import EagerExpr
from narwhals._plan.compliant.scalar import EagerScalar
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._utils import (
Implementation,
Version,
_StoresNative,
generate_temporary_column_name,
not_implemented,
)
from narwhals._utils import Implementation, Version, _StoresNative, not_implemented
from narwhals.exceptions import InvalidOperationError, ShapeError

if TYPE_CHECKING:
Expand All @@ -29,7 +25,6 @@
from typing_extensions import Self, TypeAlias

from narwhals._arrow.typing import ChunkedArrayAny, Incomplete
from narwhals._plan import expressions as ir
from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame
from narwhals._plan.arrow.namespace import ArrowNamespace
from narwhals._plan.expressions.aggregation import (
Expand All @@ -53,6 +48,8 @@
All,
IsBetween,
IsFinite,
IsFirstDistinct,
IsLastDistinct,
IsNan,
IsNull,
Not,
Expand Down Expand Up @@ -198,6 +195,9 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel
return ArrowScalar.from_native(result, name, version=self.version)
return self.from_native(result, name or self.name, self.version)

# NOTE: I'm not sure what I meant by
# > "isn't natively supported on `ChunkedArray`"
# Was that supposed to say "is only supported on `ChunkedArray`"?
def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series:
"""Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`.

Expand Down Expand Up @@ -231,10 +231,8 @@ def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr:

def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr:
series = self._dispatch_expr(node.expr, frame, name)
by = (
self._dispatch_expr(e, frame, f"<TEMP>_{idx}")
for idx, e in enumerate(node.by)
)
it_names = temp.column_names(frame)
by = (self._dispatch_expr(e, frame, nm) for e, nm in zip(node.by, it_names))
df = namespace(self)._concat_horizontal((series, *by))
names = df.columns[1:]
indices = pc.sort_indices(df.native, options=node.options.to_arrow(names))
Expand Down Expand Up @@ -342,7 +340,7 @@ def over_ordered(
# NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
options = node.sort_options.to_multiple(len(node.order_by))
idx_name = generate_temporary_column_name(8, frame.columns)
idx_name = temp.column_name(frame)
sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
if isinstance(evaluated, ArrowScalar):
Expand Down Expand Up @@ -374,6 +372,27 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self:
def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self:
raise NotImplementedError

def _is_first_last_distinct(
    self,
    node: FunctionExpr[IsFirstDistinct | IsLastDistinct],
    frame: Frame,
    name: str,
) -> Self:
    """Shared impl behind ``is_first_distinct`` / ``is_last_distinct``.

    Marks each row ``True`` when it is the first (or last) occurrence of its
    value, by grouping on the value and keeping the extreme row index per group.
    """
    # Temporary column name for the row index; seeded with `name` to avoid collision.
    idx_name = temp.column_name([name])
    # Select the index aggregation matching the variant — presumably min for
    # "first", max for "last"; verify against `fn.IS_FIRST_LAST_DISTINCT`.
    expr_ir = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name)
    series = self._dispatch_expr(node.input[0], frame, name)
    df = series.to_frame().with_row_index(idx_name)
    # One extreme row index per distinct value of the evaluated series.
    distinct_index = (
        df.group_by_names((name,))
        .agg((ir.named_ir(idx_name, expr_ir),))
        .get_column(idx_name)
        .native
    )
    # A row is first/last-distinct iff its value's row sits at one of those extremes.
    return self._with_native(fn.is_in(df.to_series().native, distinct_index), name)

# Both public methods share the implementation; the variant is recovered from
# `type(node.function)` above.
is_first_distinct = _is_first_last_distinct
is_last_distinct = _is_first_last_distinct


class ArrowScalar(
_ArrowDispatch["ArrowScalar"],
Expand Down
Loading
Loading