Commit 4feae2f

feat: Make expressions printable, rewrite internals
1 parent 63c5022 commit 4feae2f

77 files changed, +1700 -2065 lines changed

docs/how_it_works.md

Lines changed: 85 additions & 26 deletions
@@ -76,8 +76,9 @@ pn = PandasLikeNamespace(
7676
implementation=Implementation.PANDAS,
7777
version=Version.MAIN,
7878
)
79-
print(nw.col("a")._to_compliant_expr(pn))
79+
print(nw.col("a")(pn))
8080
```
81+
8182
The result from the last line above is the same as we'd get from `pn.col('a')`, and it's
8283
a `narwhals._pandas_like.expr.PandasLikeExpr` object, which we'll call `PandasLikeExpr` for
8384
short.
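
As a rough illustration of what "calling an expression with a namespace" amounts to, here is a minimal, self-contained sketch. It is not Narwhals' implementation: `ToyExpr`, `ToyNamespace`, `ToyCompliantExpr`, and the `steps` tuple are hypothetical names invented for this example. The only idea it takes from the text above is that a backend-agnostic expression, once called with a compliant namespace, yields the same object the namespace would have produced directly.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCompliantExpr:
    """Hypothetical stand-in for a backend expression such as `PandasLikeExpr`."""

    description: str

    def mean(self) -> ToyCompliantExpr:
        return ToyCompliantExpr(f"{self.description}.mean()")


class ToyNamespace:
    """Hypothetical stand-in for a compliant namespace such as `PandasLikeNamespace`."""

    def col(self, name: str) -> ToyCompliantExpr:
        return ToyCompliantExpr(f"col({name!r})")


@dataclass(frozen=True)
class ToyExpr:
    """Backend-agnostic expression: a recorded sequence of (method, args) steps."""

    steps: tuple[tuple[str, tuple], ...] = ()

    def mean(self) -> ToyExpr:
        # Record the operation instead of evaluating it.
        return ToyExpr((*self.steps, ("mean", ())))

    def __call__(self, namespace: ToyNamespace) -> ToyCompliantExpr:
        # Replay the recorded steps: the first one is resolved on the
        # namespace, each later one on the result of the previous step.
        target = namespace
        for method, args in self.steps:
            target = getattr(target, method)(*args)
        return target


def col(name: str) -> ToyExpr:
    return ToyExpr((("col", (name,)),))


pn = ToyNamespace()
print(col("a")(pn) == pn.col("a"))  # same object the namespace builds directly
print(col("a").mean()(pn).description)
```

Recording steps and replaying them on demand is one simple way to keep the user-facing expression independent of any backend until a namespace is supplied; the real library may structure this differently.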
@@ -177,7 +178,7 @@ The way you access the Narwhals-compliant wrapper depends on the object:
177178

178179
- `narwhals.DataFrame` and `narwhals.LazyFrame`: use the `._compliant_frame` attribute.
179180
- `narwhals.Series`: use the `._compliant_series` attribute.
180-
- `narwhals.Expr`: call the `._to_compliant_expr` method, and pass to it the Narwhals-compliant namespace associated with
181+
- `narwhals.Expr`: call the `.__call__` method, and pass to it the Narwhals-compliant namespace associated with
181182
the given backend.
182183

183184
🛑 BUT WAIT! What's a Narwhals-compliant namespace?
@@ -212,9 +213,10 @@ pn = PandasLikeNamespace(
212213
implementation=Implementation.PANDAS,
213214
version=Version.MAIN,
214215
)
215-
expr = (nw.col("a") + 1)._to_compliant_expr(pn)
216+
expr = (nw.col("a") + 1)(pn)
216217
print(expr)
217218
```
219+
218220
If we then extract a Narwhals-compliant dataframe from `df` by
219221
calling `._compliant_frame`, we get a `PandasLikeDataFrame` - and that's an object which we can pass `expr` to!
220222

@@ -228,6 +230,7 @@ We can then view the underlying pandas Dataframe which was produced by calling `
228230
```python exec="1" result="python" session="pandas_api_mapping" source="above"
229231
print(result._native_frame)
230232
```
233+
231234
which is the same as we'd have obtained by just using the Narwhals API directly:
232235

233236
```python exec="1" result="python" session="pandas_api_mapping" source="above"
@@ -238,49 +241,42 @@ print(nw.to_native(df.select(nw.col("a") + 1)))
238241

239242
Group-by is probably one of Polars' most significant innovations (on the syntax side) with respect
240243
to pandas. We can write something like
244+
241245
```python
242246
df: pl.DataFrame
243247
df.group_by("a").agg((pl.col("c") > pl.col("b").mean()).max())
244248
```
249+
245250
To do this in pandas, we need to either use `GroupBy.apply` (sloooow), or do some crazy manual
246251
optimisations to get it to work.
247252

248253
In Narwhals, here's what we do:
249254

250255
- if somebody uses a simple group-by aggregation (e.g. `df.group_by('a').agg(nw.col('b').mean())`),
251256
then on the pandas side we translate it to
252-
```python
253-
df: pd.DataFrame
254-
df.groupby("a").agg({"b": ["mean"]})
255-
```
257+
258+
```python
259+
df: pd.DataFrame
260+
df.groupby("a").agg({"b": ["mean"]})
261+
```
262+
256263
- if somebody passes a complex group-by aggregation, then we use `apply` and raise a `UserWarning`, warning
257264
users of the performance penalty and advising them to refactor their code so that the aggregation they perform
258265
ends up being a simple one.
259266

260-
In order to tell whether an aggregation is simple, Narwhals uses the private `_depth` attribute of `PandasLikeExpr`:
261-
262-
```python exec="1" result="python" session="pandas_impl" source="above"
263-
print(pn.col("a").mean())
264-
print((pn.col("a") + 1).mean())
265-
```
266-
267-
For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out
268-
which (efficient) elementary operation this corresponds to in pandas.
269-
270267
## Expression Metadata
271268

272-
Let's try printing out a few expressions to the console to see what they show us:
269+
Let's try printing out some compliant expressions' metadata to see what it shows us:
273270

274-
```python exec="1" result="python" session="metadata" source="above"
271+
```python exec="1" result="python" session="pandas_impl" source="above"
275272
import narwhals as nw
276273

277-
print(nw.col("a"))
278-
print(nw.col("a").mean())
279-
print(nw.col("a").mean().over("b"))
274+
print(nw.col("a")(pn)._metadata)
275+
print(nw.col("a").mean()(pn)._metadata)
276+
print(nw.col("a").mean().over("b")(pn)._metadata)
280277
```
281278

282-
Note how they tell us something about their metadata. This section is all about
283-
making sense of what that all means, what the rules are, and what it enables.
279+
This section is all about making sense of what that all means, what the rules are, and what it enables.
284280

285281
Here's a brief description of each piece of metadata:
286282

@@ -293,8 +289,6 @@ Here's a brief description of each piece of metadata:
293289
- `ExpansionKind.MULTI_UNNAMED`: Produces multiple outputs whose names depend
294290
on the input dataframe. For example, `nw.nth(0, 1)` or `nw.selectors.numeric()`.
295291

296-
- `last_node`: Kind of the last operation in the expression. See
297-
`narwhals._expression_parsing.ExprKind` for the various options.
298292
- `has_windows`: Whether the expression already contains an `over(...)` statement.
299293
- `n_orderable_ops`: How many order-dependent operations the expression contains.
300294

@@ -311,6 +305,7 @@ Here's a brief description of each piece of metadata:
311305
- `is_scalar_like`: Whether the output of the expression is always length-1.
312306
- `is_literal`: Whether the expression doesn't depend on any column but instead
313307
only on literal values, like `nw.lit(1)`.
308+
- `nodes`: List of operations which this expression applies when evaluated.
314309

315310
#### Chaining
316311

@@ -377,3 +372,67 @@ Narwhals triggers a broadcast in these situations:
377372

378373
Each backend is then responsible for doing its own broadcasting, as defined in each
379374
`CompliantExpr.broadcast` method.
375+
376+
### Elementwise push-down
377+
378+
SQL is picky about `over` operations. For example:
379+
380+
- `sum(a) over (partition by b)` is valid.
381+
- `sum(abs(a)) over (partition by b)` is valid.
382+
- `abs(sum(a)) over (partition by b)` is not valid.
383+
384+
In Polars, however, all three of the following are valid:
385+
386+
- `pl.col('a').sum().over('b')` is valid.
387+
- `pl.col('a').abs().sum().over('b')` is valid.
388+
- `pl.col('a').sum().abs().over('b')` is valid.
389+
390+
How can we retain Polars' level of flexibility when translating to SQL engines?
391+
392+
The answer is: by rewriting expressions. Specifically, we push down `over` nodes past elementwise ones.
393+
To see this, let's try printing the Narwhals equivalent of the last expression above (the one that SQL rejects):
394+
395+
```python exec="1" result="python" session="pushdown" source="above"
396+
import narwhals as nw
397+
398+
print(nw.col("a").sum().abs().over("b"))
399+
```
400+
401+
Note how Narwhals automatically inserted the `over` operation _before_ the `abs` one. In other words, instead
402+
of doing
403+
404+
- `sum` -> `abs` -> `over`
405+
406+
it did
407+
408+
- `sum` -> `over` -> `abs`
409+
410+
thus allowing the expression to be valid for SQL engines!
411+
412+
This is what we refer to as "pushing down `over` nodes". The idea is:
413+
414+
- Elementwise operations operate row-by-row and don't depend on the rows around them.
415+
- An `over` node partitions or orders a computation.
416+
- Therefore, an elementwise operation followed by an `over` operation is the same
417+
as doing the `over` operation followed by that same elementwise operation!
418+
419+
Note that the pushdown also applies to any arguments to the elementwise operation.
420+
For example, if we have
421+
422+
```python
423+
(nw.col("a").sum() + nw.col("b").sum()).over("c")
424+
```
425+
426+
then `+` is an elementwise operation and so can be swapped with `over`. We just need
427+
to take care to apply the `over` operation to all the arguments of `+`, so that we
428+
end up with
429+
430+
```python
431+
nw.col("a").sum().over("c") + nw.col("b").sum().over("c")
432+
```
433+
434+
In general, query optimisation is out-of-scope for Narwhals. We consider this
435+
expression rewrite acceptable because:
436+
437+
- It's simple.
438+
- It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends.
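
The push-down rule described in the "Elementwise push-down" section can be sketched in a few lines of plain Python. This is a toy model under stated assumptions, not the code in this commit: `Node`, `ELEMENTWISE`, `over`, and `render` are hypothetical names, an expression is modelled as a tuple of nodes applied left to right, and only `abs` and `__add__` are treated as elementwise.

```python
from __future__ import annotations

from dataclasses import dataclass

# Assumption: a tiny whitelist of elementwise operations for this sketch.
ELEMENTWISE = {"abs", "__add__"}


@dataclass(frozen=True)
class Node:
    name: str
    # For binary elementwise ops, `args` holds the other operand chains.
    args: tuple = ()


def over(chain: tuple, *partition_by: str) -> tuple:
    """Append an `over` node, pushing it past trailing elementwise nodes."""
    head = list(chain)
    trailing = []
    while head and head[-1].name in ELEMENTWISE:
        node = head.pop()
        # Any operand that is itself an expression chain gets the same window.
        new_args = tuple(
            over(arg, *partition_by) if isinstance(arg, tuple) else arg
            for arg in node.args
        )
        trailing.append(Node(node.name, new_args))
    head.append(Node("over", partition_by))
    # `trailing` was collected back to front, so restore the original order.
    return tuple(head) + tuple(reversed(trailing))


def render(chain: tuple) -> str:
    return " -> ".join(node.name for node in chain)


# sum -> abs -> over  becomes  sum -> over -> abs
unary = (Node("col", ("a",)), Node("sum"), Node("abs"))
print(render(over(unary, "b")))

# (a.sum() + b.sum()).over("c"): the window is applied to both operands.
rhs = (Node("col", ("b",)), Node("sum"))
binary = (Node("col", ("a",)), Node("sum"), Node("__add__", (rhs,)))
pushed = over(binary, "c")
print(render(pushed), "| rhs:", render(pushed[-1].args[0]))
```

The sketch only handles elementwise nodes at the tail of the chain, which is the case the section discusses; a real rewrite would also need to decide which operations count as elementwise.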

narwhals/_arrow/dataframe.py

Lines changed: 2 additions & 2 deletions
@@ -424,7 +424,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self:
424424
if subset is None:
425425
return self._with_native(self.native.drop_null(), validate_column_names=False)
426426
plx = self.__narwhals_namespace__()
427-
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
427+
mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True)
428428
return self.filter(mask)
429429

430430
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
@@ -496,7 +496,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
496496
plx._series.from_iterable(data, context=self, name=name)
497497
)
498498
else:
499-
rank = plx.col(order_by[0]).rank("ordinal", descending=False)
499+
rank = plx.col([order_by[0]]).rank("ordinal", descending=False)
500500
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
501501
return self.select(row_index, plx.all())
502502

narwhals/_arrow/expr.py

Lines changed: 1 addition & 14 deletions
@@ -32,23 +32,17 @@ def __init__(
3232
self,
3333
call: EvalSeries[ArrowDataFrame, ArrowSeries],
3434
*,
35-
depth: int,
36-
function_name: str,
3735
evaluate_output_names: EvalNames[ArrowDataFrame],
3836
alias_output_names: AliasNames | None,
3937
version: Version,
4038
scalar_kwargs: ScalarKwargs | None = None,
4139
implementation: Implementation | None = None,
4240
) -> None:
4341
self._call = call
44-
self._depth = depth
45-
self._function_name = function_name
46-
self._depth = depth
4742
self._evaluate_output_names = evaluate_output_names
4843
self._alias_output_names = alias_output_names
4944
self._version = version
50-
self._scalar_kwargs = scalar_kwargs or {}
51-
self._metadata: ExprMetadata | None = None
45+
self._opt_metadata: ExprMetadata | None = None
5246

5347
@classmethod
5448
def from_column_names(
@@ -57,7 +51,6 @@ def from_column_names(
5751
/,
5852
*,
5953
context: _LimitedContext,
60-
function_name: str = "",
6154
) -> Self:
6255
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
6356
try:
@@ -74,8 +67,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
7467

7568
return cls(
7669
func,
77-
depth=0,
78-
function_name=function_name,
7970
evaluate_output_names=evaluate_column_names,
8071
alias_output_names=None,
8172
version=context._version,
@@ -93,8 +84,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
9384

9485
return cls(
9586
func,
96-
depth=0,
97-
function_name="nth",
9887
evaluate_output_names=cls._eval_names_indices(column_indices),
9988
alias_output_names=None,
10089
version=context._version,
@@ -160,8 +149,6 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
160149

161150
return self.__class__(
162151
func,
163-
depth=self._depth + 1,
164-
function_name=self._function_name + "->over",
165152
evaluate_output_names=self._evaluate_output_names,
166153
alias_output_names=self._alias_output_names,
167154
version=self._version,

narwhals/_arrow/group_by.py

Lines changed: 5 additions & 5 deletions
@@ -71,10 +71,10 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
7171
output_names, aliases = evaluate_output_names_and_aliases(
7272
expr, self.compliant, exclude
7373
)
74-
75-
if expr._depth == 0:
74+
md = expr._metadata
75+
if len(list(md.op_nodes_reversed())) == 1:
7676
# e.g. `agg(nw.len())`
77-
if expr._function_name != "len": # pragma: no cover
77+
if next(md.op_nodes_reversed()).name != "len": # pragma: no cover
7878
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
7979
raise AssertionError(msg)
8080

@@ -85,8 +85,8 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
8585

8686
function_name = self._leaf_name(expr)
8787
if function_name in {"std", "var"}:
88-
assert "ddof" in expr._scalar_kwargs # noqa: S101
89-
option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
88+
last_node = next(md.op_nodes_reversed())
89+
option: Any = pc.VarianceOptions(**last_node.kwargs)
9090
elif function_name in {"len", "n_unique"}:
9191
option = pc.CountOptions(mode="all")
9292
elif function_name == "count":
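
The `group_by.py` hunk above replaces the `_depth`, `_function_name`, and `_scalar_kwargs` lookups with an inspection of the expression's last node. The sketch below mimics that pattern with toy stand-ins so it runs without pyarrow. `ToyNode`, `ToyMetadata`, and `aggregation_options` are hypothetical; only the method name `op_nodes_reversed` and the `name` / `kwargs` attributes are taken from the diff, and their real behaviour is assumed rather than verified. Plain dicts stand in for `pc.VarianceOptions` and `pc.CountOptions`.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Iterator


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical operation node: a name plus the keyword arguments it was called with."""

    name: str
    kwargs: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class ToyMetadata:
    """Hypothetical metadata holding the nodes in application order."""

    nodes: tuple[ToyNode, ...]

    def op_nodes_reversed(self) -> Iterator[ToyNode]:
        # Last applied operation first, mirroring how the diff uses `next(...)`.
        return reversed(self.nodes)


def aggregation_options(md: ToyMetadata) -> dict[str, Any]:
    """Pick kernel options from the last node (dicts stand in for pyarrow option objects)."""
    last_node = next(md.op_nodes_reversed())
    if last_node.name in {"std", "var"}:
        # The node already carries `ddof`, so its kwargs are forwarded unchanged.
        return dict(last_node.kwargs)
    if last_node.name in {"len", "n_unique"}:
        return {"mode": "all"}
    return {}


std_md = ToyMetadata((ToyNode("col"), ToyNode("std", {"ddof": 1})))
len_md = ToyMetadata((ToyNode("len"),))
print(aggregation_options(std_md))
print(aggregation_options(len_md))
print(len(list(len_md.op_nodes_reversed())) == 1)  # the `agg(nw.len())` case
```

Forwarding the node's own keyword arguments removes the need for a separate per-expression `scalar_kwargs` store, which appears to be the motivation for the change in the hunk.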
