Merged

53 commits
cb470b4
refactor: Use `temp.column_name(s)` some more
dangotbanned Oct 1, 2025
23e9d43
fix(typing): Resolve some cases for `flatten_hash_safe`
dangotbanned Oct 1, 2025
f77bb4c
feat(expr-ir): Impl `acero.sort_by`
dangotbanned Oct 2, 2025
36ddce0
test: Port over `is_first_distinct` tests
dangotbanned Oct 2, 2025
0e49f57
chore: Add `Compliant{Expr,Scalar}.is_{first,last}_distinct`
dangotbanned Oct 2, 2025
a5f192c
test: Update to cover `is_last_distinct` as well
dangotbanned Oct 2, 2025
6a1b08a
feat(DRAFT): Initial `is_first_distinct` impl
dangotbanned Oct 2, 2025
1c026bf
test: Port over more cases
dangotbanned Oct 3, 2025
e7e8a04
refactor: Generalize `is_first_distinct` impl
dangotbanned Oct 3, 2025
2d46521
feat: Add `is_last_distinct`
dangotbanned Oct 3, 2025
cfb775d
refactor: Make both `is_*_distinct` methods, aliases
dangotbanned Oct 3, 2025
9db603b
feat: (Properly) add `get_column`, `to_series`
dangotbanned Oct 3, 2025
f8255d3
chore: Add `pc.is_in` wrapper
dangotbanned Oct 3, 2025
6fe2a0a
docs: Add detail to `FunctionFlags.LENGTH_PRESERVING`
dangotbanned Oct 3, 2025
938befb
test: More test porting
dangotbanned Oct 3, 2025
516f4a6
typo
dangotbanned Oct 3, 2025
ead4e62
feat(DRAFT): Some progress on `hashjoin` port
dangotbanned Oct 4, 2025
273bdcc
fix: Correctly pass down join keys
dangotbanned Oct 5, 2025
ce37617
test: Port over inner, left & clean up
dangotbanned Oct 5, 2025
18ef26a
test: Add `test_suffix`
dangotbanned Oct 5, 2025
94baf1e
test: Add `how="cross"` tests
dangotbanned Oct 5, 2025
733b45a
test: Add `how={"anti","semi"}` tests
dangotbanned Oct 5, 2025
ce321e0
test: replace `"antananarivo"`->`"a"`, `"bob"`->`"b"`
dangotbanned Oct 5, 2025
cc0d379
test: Port the other duplicate test
dangotbanned Oct 5, 2025
dd40e3a
test: Make all the xfails more visible
dangotbanned Oct 5, 2025
d1a1785
feat(DRAFT): Initial acero cross-join impl
dangotbanned Oct 5, 2025
77e55b3
refactor: Only expose `acero.join_tables`
dangotbanned Oct 5, 2025
8f7d2f3
chore: Start factoring-out `Table` dependency
dangotbanned Oct 5, 2025
b0c2a4d
Merge branch 'oh-nodes' into expr-ir/acero-order-by
dangotbanned Oct 6, 2025
d42f5de
refactor(typing): Use `IntoExprColumn` some more
dangotbanned Oct 6, 2025
b8a58c1
refactor: Split up `_parse_sort_by`
dangotbanned Oct 6, 2025
05c63fd
Make a start on `DataFrame.filter`
dangotbanned Oct 6, 2025
025213d
fill out slightly more `filter`
dangotbanned Oct 6, 2025
3e94449
get typing working again (kinda)
dangotbanned Oct 6, 2025
a611bc9
feat(DRAFT): Support `filter(list[bool])`
dangotbanned Oct 6, 2025
d514ad0
feat: Support single `Series` as well
dangotbanned Oct 6, 2025
d452920
test: Use `parametrize`
dangotbanned Oct 6, 2025
4c7c23d
feat: Add predicate expansion
dangotbanned Oct 6, 2025
2ebca30
feat(expr-ir): Full `DataFrame.filter` support
dangotbanned Oct 6, 2025
1b66786
test: Merge the anti/semi tests
dangotbanned Oct 6, 2025
fd38911
test: parametrize exception messages
dangotbanned Oct 6, 2025
3537cac
test: relax more error messages
dangotbanned Oct 6, 2025
b5ef86b
typo
dangotbanned Oct 7, 2025
8433b2d
test: Add `test_filter_mask_mixed`
dangotbanned Oct 7, 2025
7668abb
fix: Raise on duplicate column names
dangotbanned Oct 7, 2025
3ca43d1
cov
dangotbanned Oct 7, 2025
0f06479
perf: Avoid multiple collections during cross join
dangotbanned Oct 7, 2025
7e9ee74
test: Stop repeating the same data so many times
dangotbanned Oct 7, 2025
1523dbb
test: Add some cases from polars
dangotbanned Oct 8, 2025
a479f32
fix: typing mypy
dangotbanned Oct 8, 2025
8e840e0
feat(expr-ir): Full-er `DataFrame.filter` support
dangotbanned Oct 8, 2025
af26916
refactor: Simplify the `NonCrossJoinStrategy` split
dangotbanned Oct 8, 2025
6aaf75d
test: Convert raising test into a conformance test
dangotbanned Oct 8, 2025
Files changed
4 changes: 4 additions & 0 deletions narwhals/_plan/_expr_ir.py
@@ -304,3 +304,7 @@ def is_column(self, *, allow_aliasing: bool = False) -> bool:

        ir = self.expr
        return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing)


def named_ir(name: str, expr: ExprIRT, /) -> NamedIR[ExprIRT]:
    return NamedIR(expr=expr, name=name)
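
A minimal sketch of the intent, assuming `Column(name=...)` is the column node's constructor (illustrative only):

```python
# Hypothetical usage: pair an expression node with its output name.
ir = named_ir("a", Column(name="a"))
assert ir.name == "a" and ir.is_column()
```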
20 changes: 20 additions & 0 deletions narwhals/_plan/_parse.py
@@ -13,6 +13,7 @@
    is_iterable_polars_error,
)
from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Iterator
@@ -124,6 +125,25 @@ def parse_predicates_constraints_into_expr_ir(
    return _combine_predicates(all_predicates)


def parse_sort_by_into_seq_of_expr_ir(
    by: OneOrIterable[IntoExprColumn] = (), *more_by: IntoExprColumn
) -> Seq[ExprIR]:
    """Parse `DataFrame.sort` and `Expr.sort_by` keys into a flat sequence of `ExprIR` nodes."""
    return tuple(_parse_sort_by_into_iter_expr_ir(by, more_by))


# TODO @dangotbanned: Review the rejection predicate
# It doesn't cover all length-changing expressions, only aggregations/literals
def _parse_sort_by_into_iter_expr_ir(
    by: OneOrIterable[IntoExprColumn], more_by: Iterable[IntoExprColumn]
) -> Iterator[ExprIR]:
    for e in _parse_into_iter_expr_ir(by, *more_by):
        if e.is_scalar:
            msg = f"All sort key expressions must preserve length, but got:\n{e!r}"
            raise InvalidOperationError(msg)
        yield e


def _parse_into_iter_expr_ir(
    first_input: OneOrIterable[IntoExpr], *more_inputs: IntoExpr, **named_inputs: IntoExpr
) -> Iterator[ExprIR]:
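
A sketch of the intended behavior, assuming public `narwhals` expressions and plain strings both flow through as `IntoExprColumn` (illustrative only):

```python
import narwhals as nw

from narwhals._plan._parse import parse_sort_by_into_seq_of_expr_ir

# Length-preserving keys are accepted, flattened, and returned in order.
keys = parse_sort_by_into_seq_of_expr_ir(["a", nw.col("b")])

# An aggregation collapses to a scalar, so it is rejected as a sort key.
parse_sort_by_into_seq_of_expr_ir(nw.col("a").sum())  # InvalidOperationError
```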
207 changes: 195 additions & 12 deletions narwhals/_plan/arrow/acero.py
@@ -25,26 +25,40 @@
import pyarrow.compute as pc  # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals._plan.common import ensure_list_str, flatten_hash_safe, temp
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import OneOrSeq
from narwhals.typing import SingleColSelector
from narwhals.typing import JoinStrategy, SingleColSelector

if TYPE_CHECKING:
    from collections.abc import Callable, Collection, Iterable, Iterator
    from collections.abc import (
        Callable,
        Collection,
        Iterable,
        Iterator,
        Mapping,
        Sequence,
    )

    from typing_extensions import TypeAlias
    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions as _AggregateOptions,
        Aggregation as _Aggregation,
    )
    from narwhals._plan.arrow.group_by import AggSpec
    from narwhals._plan.arrow.typing import NullPlacement
    from narwhals._plan.arrow.typing import (
        ArrowAny,
        JoinTypeSubset,
        NullPlacement,
        ScalarAny,
    )
    from narwhals._plan.typing import OneOrIterable, Order, Seq
    from narwhals.typing import NonNestedLiteral

Incomplete: TypeAlias = Any
Expr: TypeAlias = pc.Expression
IntoExpr: TypeAlias = "Expr | NonNestedLiteral"
IntoExpr: TypeAlias = "Expr | NonNestedLiteral | ScalarAny"
Field: TypeAlias = Union[Expr, SingleColSelector]
"""Anything that passes as a single item in [`_compute._ensure_field_ref`].

@@ -57,12 +71,28 @@
Opts: TypeAlias = "AggregateOptions | None"
OutputName: TypeAlias = str

IntoDecl: TypeAlias = Union[pa.Table, Decl]
"""An in-memory table, or a plan that began with one."""

_THREAD_UNSAFE: Final = frozenset[Aggregation](
    ("hash_first", "hash_last", "first", "last")
)
col = pc.field
lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar)
"""Alias for `pyarrow.compute.scalar`."""
lit = cast("Callable[[NonNestedLiteral | ScalarAny], Expr]", pc.scalar)
"""Alias for `pyarrow.compute.scalar`.

Extends the signature from `bool | float | str`.

See https://github.com/apache/arrow/pull/47609#discussion_r2392499842
"""

_HOW_JOIN: Mapping[JoinStrategy, JoinTypeSubset] = {
    "inner": "inner",
    "left": "left outer",
    "full": "full outer",
    "anti": "left anti",
    "semi": "left semi",
}


# NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg
@@ -72,8 +102,17 @@ def can_thread(function_name: str, /) -> bool:
    return function_name not in _THREAD_UNSAFE


def cols_iter(names: Iterable[str], /) -> Iterator[Expr]:
    for name in names:
        yield col(name)


def _is_expr(obj: Any) -> TypeIs[pc.Expression]:
    return isinstance(obj, pc.Expression)


def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr:
    if isinstance(into, pc.Expression):
    if _is_expr(into):
        return into
    if isinstance(into, str) and not str_as_lit:
        return col(into)
@@ -177,6 +216,35 @@ def project(**named_exprs: IntoExpr) -> Decl:
    return _project(names=named_exprs.keys(), exprs=exprs)

def _add_column(
    native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny
) -> pa.Table:
    if isinstance(values, (pa.ChunkedArray, pa.Array)):
        return native.add_column(index, name, values)
    column = values if _is_expr(values) else lit(values)
    schema = native.schema
    schema_names = schema.names
    if index == 0:
        names: Sequence[str] = (name, *schema_names)
        exprs = (column, *cols_iter(schema_names))
    elif index == native.num_columns:
        names = (*schema_names, name)
        exprs = (*cols_iter(schema_names), column)
    else:
        schema_names.insert(index, name)
        names = schema_names
        exprs = tuple(_parse_into_iter_expr(nm if nm != name else column for nm in names))
    return collect(table_source(native), _project(exprs, names))


def append_column(native: pa.Table, name: str, values: IntoExpr | ArrowAny) -> pa.Table:
    return _add_column(native, native.num_columns, name, values)


def prepend_column(native: pa.Table, name: str, values: IntoExpr | ArrowAny) -> pa.Table:
    return _add_column(native, 0, name, values)
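
A minimal sketch of the two paths (values are illustrative): arrays take `pa.Table.add_column` directly, while expressions and literals go through an acero projection:

```python
import pyarrow as pa

t = pa.table({"a": [1, 2, 3]})
t_flag = append_column(t, "flag", True)  # literal broadcast via `_project`
t_idx = prepend_column(t, "idx", pa.chunked_array([[0, 1, 2]]))  # fast path
```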


def _order_by(
    sort_keys: Iterable[tuple[str, Order]] = (),
    *,
@@ -189,10 +257,84 @@ def _order_by(
    return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement))


# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero`
def sort_by(*args: Any, **kwds: Any) -> Decl:
    msg = "Should convert from polars args -> use `_order_by`"
    raise NotImplementedError(msg)
def sort_by(
    by: OneOrIterable[str],
    *more_by: str,
    descending: OneOrIterable[bool] = False,
    nulls_last: bool = False,
) -> Decl:
    return SortMultipleOptions.parse(
        descending=descending, nulls_last=nulls_last
    ).to_arrow_acero(tuple(flatten_hash_safe((by, more_by))))
Comment on lines +261 to +269

dangotbanned (Member, Author) commented on Oct 2, 2025:
As of feat(expr-ir): Impl acero.sort_by, I still need to make use of this in a plan.

A good candidate might be in either/both of:

over(order_by=...)

def over_ordered(
    self, node: ir.OrderedWindowExpr, frame: Frame, name: str
) -> Self | Scalar:
    if node.partition_by:
        msg = f"Need to implement `group_by`, `join` for:\n{node!r}"
        raise NotImplementedError(msg)
    # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
    sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
    options = node.sort_options.to_multiple(len(node.order_by))
    idx_name = temp.column_name(frame)
    sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
    evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
    if isinstance(evaluated, ArrowScalar):
        # NOTE: We're already sorted, defer broadcasting to the outer context
        # Wouldn't be suitable for partitions, but will be fine here
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70
        return self._with_native(evaluated.native, name)
    indices = pc.sort_indices(sorted_context.get_column(idx_name).native)
    height = len(sorted_context)
    result = evaluated.broadcast(height).native.take(indices)
    return self._with_native(result, name)

is_{first,last}_distinct

def is_first_distinct(self) -> Self:
    import numpy as np  # ignore-banned-import

    row_number = pa.array(np.arange(len(self)))
    col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    first_distinct_index = (
        pa.Table.from_arrays([self.native], names=[self.name])
        .append_column(col_token, row_number)
        .group_by(self.name)
        .aggregate([(col_token, "min")])
        .column(f"{col_token}_min")
    )
    return self._with_native(pc.is_in(row_number, first_distinct_index))


def is_last_distinct(self) -> Self:
    import numpy as np  # ignore-banned-import

    row_number = pa.array(np.arange(len(self)))
    col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    last_distinct_index = (
        pa.Table.from_arrays([self.native], names=[self.name])
        .append_column(col_token, row_number)
        .group_by(self.name)
        .aggregate([(col_token, "max")])
        .column(f"{col_token}_max")
    )
    return self._with_native(pc.is_in(row_number, last_distinct_index))
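
Back in `acero.py`, the new `sort_by` composes like the other declaration helpers; a hedged usage sketch, assuming `SortMultipleOptions.parse` accepts per-key flags:

```python
import pyarrow as pa

t = pa.table({"a": [2, 1, 2], "b": [1.0, 2.0, 3.0]})
# Sort ascending by "a", breaking ties with "b" descending.
result = collect(table_source(t), sort_by("a", "b", descending=[False, True]))
```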



# TODO @dangotbanned: Add a variant that doesn't depend on 2x tables
# - `_join_options`: just needs iterables (currently sourced from `schema.names``)
# -`_hashjoin`: can now accept `Declaration`s in either case
def join(
left: pa.Table,
right: pa.Table,
how: JoinTypeSubset,
left_on: OneOrIterable[str],
right_on: OneOrIterable[str],
suffix: str = "_right",
*,
coalesce_keys: bool = True,
) -> Decl:
"""Heavily based on [`pyarrow.acero._perform_join`].

[`pyarrow.acero._perform_join`]: https://github.com/apache/arrow/blob/f7320c9a40082639f9e0cf8b3075286e3fc6c0b9/python/pyarrow/acero.py#L82-L260
"""
opts = _join_options(
how,
left_on,
right_on,
suffix,
left.schema.names,
right.schema.names,
coalesce_keys=coalesce_keys,
)
return _hashjoin(left, right, opts)
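
A minimal usage sketch (tables are illustrative):

```python
import pyarrow as pa

left = pa.table({"id": [1, 2], "x": ["a", "b"]})
right = pa.table({"id": [2, 3], "y": [10, 20]})
# Inner join on "id"; the coalesced key appears once, from the left side.
result = collect(join(left, right, "inner", "id", "id"))
# -> 1 row: id=2, x="b", y=10
```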


def _join_options(
    how: JoinTypeSubset,
    left_on: OneOrIterable[str],
    right_on: OneOrIterable[str],
    suffix: str = "_right",
    left_names: Iterable[str] | None = None,
    right_names: Iterable[str] = (),
    *,
    coalesce_keys: bool = True,
) -> pac.HashJoinNodeOptions:
    right_on = ensure_list_str(right_on)
    rhs_names: Iterable[str] | None = None
    # polars full join does not coalesce keys
    if not (coalesce_keys and (how != "full outer")):
        lhs_names = None
    else:
        lhs_names = left_names
        if how in {"inner", "left outer"}:
            rhs_names = (name for name in right_names if name not in right_on)
    tp: Incomplete = pac.HashJoinNodeOptions
    return tp(  # type: ignore[no-any-return]
        join_type=how,
        left_keys=ensure_list_str(left_on),
        right_keys=right_on,
        left_output=lhs_names,
        right_output=rhs_names,
        output_suffix_for_right=suffix,
    )
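
To make the coalescing rule concrete, a small sketch of the inner-join case (names are illustrative):

```python
# With coalesce_keys=True, the right "id" is excluded from `right_output`,
# so the join key is emitted exactly once in the result.
opts = _join_options(
    "inner", "id", "id", left_names=["id", "x"], right_names=["id", "y"]
)
```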


def _into_decl(source: IntoDecl, /) -> Decl:
    return source if not isinstance(source, pa.Table) else table_source(source)


def _hashjoin(
    left: IntoDecl, right: IntoDecl, /, options: pac.HashJoinNodeOptions
) -> Decl:
    return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)])


def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table:
@@ -251,3 +393,44 @@ def select_names_table(
    native: pa.Table, column_names: OneOrIterable[str], *more_names: str
) -> pa.Table:
    return collect(table_source(native), select_names(column_names, *more_names))


def join_tables(
    left: pa.Table,
    right: pa.Table,
    how: JoinStrategy,
    left_on: OneOrIterable[str] | None = (),
    right_on: OneOrIterable[str] | None = (),
    suffix: str = "_right",
    *,
    coalesce_keys: bool = True,
) -> pa.Table:
    if how == "cross":
        return _join_cross_tables(left, right, suffix)
    join_type = _HOW_JOIN[how]
    left_on = left_on or ()
    right_on = right_on or left_on
    decl = join(
        left, right, join_type, left_on, right_on, suffix, coalesce_keys=coalesce_keys
    )
    return collect(decl)
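
A hedged sketch of the strategy mapping in action (tables are illustrative):

```python
import pyarrow as pa

left = pa.table({"id": [1, 2, 3], "x": ["a", "b", "c"]})
right = pa.table({"id": [2, 3]})
join_tables(left, right, "semi", "id")  # keeps rows with a match: id 2, 3
join_tables(left, right, "anti", "id")  # keeps rows without one: id 1
```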


# TODO @dangotbanned: Very rough start to get tests passing
# - Decouple from `pa.Table` & collecting 3 times
# - Reuse the plan from `_add_column`
# - Write some more specialized parsers for
#   [x] column names
#   [ ] indices?
def _join_cross_tables(
    left: pa.Table, right: pa.Table, suffix: str = "_right"
) -> pa.Table:
    key_token = temp.column_name(chain(left.column_names, right.column_names))
    result = join_tables(
        prepend_column(left, key_token, 0),
        prepend_column(right, key_token, 0),
        how="inner",
        left_on=key_token,
        suffix=suffix,
    )
    return result.remove_column(0)
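
The constant-key trick turns the cross join into an inner join that every row pair satisfies; a quick sketch:

```python
import pyarrow as pa

left = pa.table({"a": [1, 2]})
right = pa.table({"b": ["x", "y"]})
join_tables(left, right, "cross").num_rows  # -> 4, every pairing
```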
36 changes: 31 additions & 5 deletions narwhals/_plan/arrow/dataframe.py
@@ -9,32 +9,33 @@
import pyarrow.compute as pc  # ignore-banned-import

from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow import acero, functions as fn
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.compliant.dataframe import EagerDataFrame
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._plan.typing import Seq
from narwhals._utils import Version, parse_columns_to_drop
from narwhals._utils import Implementation, Version, parse_columns_to_drop
from narwhals.schema import Schema

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence

    from typing_extensions import Self

    from narwhals._arrow.typing import ChunkedArrayAny  # noqa: F401
    from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
    from narwhals._arrow.typing import ChunkedArrayAny
    from narwhals._plan.arrow.namespace import ArrowNamespace
    from narwhals._plan.expressions import ExprIR, NamedIR
    from narwhals._plan.options import SortMultipleOptions
    from narwhals._plan.typing import Seq
    from narwhals.dtypes import DType
    from narwhals.typing import IntoSchema
    from narwhals.typing import IntoSchema, JoinStrategy


class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]):
    implementation = Implementation.PYARROW
    _native: pa.Table
    _version: Version

@@ -144,3 +145,28 @@ def select_names(self, *column_names: str) -> Self:
    def row(self, index: int) -> tuple[Any, ...]:
        row = self.native.slice(index, 1)
        return tuple(chain.from_iterable(row.to_pydict().values()))

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str = "_right",
    ) -> Self:
        left, right = self.native, other.native
        result = acero.join_tables(left, right, how, left_on, right_on, suffix=suffix)
        return self._with_native(result)

    def filter(self, predicate: NamedIR | Series) -> Self:
        mask: pc.Expression | ChunkedArrayAny
        if not fn.is_series(predicate):
            resolved = Expr.from_named_ir(predicate, self)
            if isinstance(resolved, Expr):
                mask = resolved.broadcast(len(self)).native
            else:
                mask = acero.lit(resolved.native)
        else:
            mask = predicate.native
        return self._with_native(self.native.filter(mask))
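
The closing `self.native.filter(mask)` works because `pa.Table.filter` accepts either a boolean mask or a compute expression; a minimal sketch of that duality:

```python
import pyarrow as pa
import pyarrow.compute as pc

t = pa.table({"a": [1, 2, 3]})
t.filter(pa.chunked_array([[True, False, True]]))  # mask path
t.filter(pc.field("a") > 1)                        # expression path
```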