diff --git a/docs/how_it_works.md b/docs/how_it_works.md index add65aacc6..7adff7a0a5 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -78,6 +78,7 @@ pn = PandasLikeNamespace( ) print(nw.col("a")._to_compliant_expr(pn)) ``` + The result from the last line above is the same as we'd get from `pn.col('a')`, and it's a `narwhals._pandas_like.expr.PandasLikeExpr` object, which we'll call `PandasLikeExpr` for short. @@ -215,6 +216,7 @@ pn = PandasLikeNamespace( expr = (nw.col("a") + 1)._to_compliant_expr(pn) print(expr) ``` + If we then extract a Narwhals-compliant dataframe from `df` by calling `._compliant_frame`, we get a `PandasLikeDataFrame` - and that's an object which we can pass `expr` to! @@ -228,6 +230,7 @@ We can then view the underlying pandas Dataframe which was produced by calling ` ```python exec="1" result="python" session="pandas_api_mapping" source="above" print(result._native_frame) ``` + which is the same as we'd have obtained by just using the Narwhals API directly: ```python exec="1" result="python" session="pandas_api_mapping" source="above" @@ -238,10 +241,12 @@ print(nw.to_native(df.select(nw.col("a") + 1))) Group-by is probably one of Polars' most significant innovations (on the syntax side) with respect to pandas. We can write something like + ```python df: pl.DataFrame df.group_by("a").agg((pl.col("c") > pl.col("b").mean()).max()) ``` + To do this in pandas, we need to either use `GroupBy.apply` (sloooow), or do some crazy manual optimisations to get it to work. @@ -249,38 +254,85 @@ In Narwhals, here's what we do: - if somebody uses a simple group-by aggregation (e.g. `df.group_by('a').agg(nw.col('b').mean())`), then on the pandas side we translate it to - ```python - df: pd.DataFrame - df.groupby("a").agg({"b": ["mean"]}) - ``` + + ```python + df: pd.DataFrame + df.groupby("a").agg({"b": ["mean"]}) + ``` + - if somebody passes a complex group-by aggregation, then we use `apply` and raise a `UserWarning`, warning users of the performance penalty and advising them to refactor their code so that the aggregation they perform ends up being a simple one. -In order to tell whether an aggregation is simple, Narwhals uses the private `_depth` attribute of `PandasLikeExpr`: +## Nodes + +If we have a Narwhals expression, we can look at the operations which make it up by accessing `_nodes`: + +```python exec="1" result="python" session="pandas_impl" source="above" +import narwhals as nw + +expr = nw.col("a").abs().std(ddof=1) + nw.col("b") +print(expr._nodes) +``` + +Each node represents an operation. Here, we have 4 operations: + +1. Given some dataframe, select column `'a'`. +2. Take its absolute value. +3. Take its standard deviation, with `ddof=1`. +4. Sum column `'b'`. + +Let's take a look at a couple of these nodes. Let's start with the third one: + +```python exec="1" result="python" session="pandas_impl" source="above" +print(expr._nodes[2].as_dict()) +``` + +This tells us a few things: + +- We're performing an aggregation. +- The name of the function is `'std'`. This will be looked up in the compliant object. +- It takes keyword arguments `ddof=1`. +- We'll look at `exprs`, `str_as_lit`, and `allow_multi_output` later. + +In order for the evaluation to succeed, then `PandasLikeExpr` must have a `std` method defined +on it, which takes a `ddof` argument. And this is what the `CompliantExpr` Protocol is for: so +long as a backend's implementation complies with the protocol, then Narwhals will be able to +unpack a `ExprNode` and turn it into a valid call. + +Let's take a look at the fourth node: + +```python exec="1" result="python" session="pandas_impl" source="above" +print(expr._nodes[3].as_dict()) +``` + +Note how now, the `exprs` attribute is populated. Indeed, we are summing another expression: `col('b')`. +The `exprs` parameter holds arguments which are either expressions, or should be interpreted as expressions. +The `str_as_lit` parameter tells us whether string literals should be interpreted as literals (e.g. `lit('foo')`) +or columns (e.g. `col('foo')`). Finally `allow_multi_output` tells us whether multi-outuput expressions +(more on this in the next section) are allowed to appear in `exprs`. + +Note that the expression in `exprs` also has its own nodes: ```python exec="1" result="python" session="pandas_impl" source="above" -print(pn.col("a").mean()) -print((pn.col("a") + 1).mean()) +print(expr._nodes[3].exprs[0]._nodes) ``` -For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out -which (efficient) elementary operation this corresponds to in pandas. +It's nodes all the way down! ## Expression Metadata -Let's try printing out a few expressions to the console to see what they show us: +Let's try printing out some compliant expressions' metadata to see what it shows us: -```python exec="1" result="python" session="metadata" source="above" +```python exec="1" result="python" session="pandas_impl" source="above" import narwhals as nw -print(nw.col("a")) -print(nw.col("a").mean()) -print(nw.col("a").mean().over("b")) +print(nw.col("a")._to_compliant_expr(pn)._metadata) +print(nw.col("a").mean()._to_compliant_expr(pn)._metadata) +print(nw.col("a").mean().over("b")._to_compliant_expr(pn)._metadata) ``` -Note how they tell us something about their metadata. This section is all about -making sense of what that all means, what the rules are, and what it enables. +This section is all about making sense of what that all means, what the rules are, and what it enables. Here's a brief description of each piece of metadata: @@ -293,8 +345,6 @@ Here's a brief description of each piece of metadata: - `ExpansionKind.MULTI_UNNAMED`: Produces multiple outputs whose names depend on the input dataframe. For example, `nw.nth(0, 1)` or `nw.selectors.numeric()`. -- `last_node`: Kind of the last operation in the expression. See - `narwhals._expression_parsing.ExprKind` for the various options. - `has_windows`: Whether the expression already contains an `over(...)` statement. - `n_orderable_ops`: How many order-dependent operations the expression contains. @@ -311,8 +361,9 @@ Here's a brief description of each piece of metadata: - `is_scalar_like`: Whether the output of the expression is always length-1. - `is_literal`: Whether the expression doesn't depend on any column but instead only on literal values, like `nw.lit(1)`. +- `nodes`: List of operations which this expression applies when evaluated. -#### Chaining +### Chaining Say we have `expr.expr_method()`. How does `expr`'s `ExprMetadata` change? This depends on `expr_method`. Details can be found in `narwhals/_expression_parsing`, @@ -356,7 +407,7 @@ is: then `n_orderable_ops` is decreased by 1. This is the only way that `n_orderable_ops` can decrease. -### Broadcasting +## Broadcasting When performing comparisons between columns and aggregations or scalars, we operate as if the aggregation or scalar was broadcasted to the length of the whole column. For example, if we @@ -377,3 +428,67 @@ Narwhals triggers a broadcast in these situations: Each backend is then responsible for doing its own broadcasting, as defined in each `CompliantExpr.broadcast` method. + +## Elementwise push-down + +SQL is picky about `over` operations. For example: + +- `sum(a) over (partition by b)` is valid. +- `sum(abs(a)) over (partition by b)` is valid. +- `abs(sum(a)) over (partition by b)` is not valid. + +In Polars, however, all three of + +- `pl.col('a').sum().over('b')` is valid. +- `pl.col('a').abs().sum().over('b')` is valid. +- `pl.col('a').sum().abs().over('b')` is valid. + +How can we retain Polars' level of flexibility when translating to SQL engines? + +The answer is: by rewriting expressions. Specifically, we push down `over` nodes past elementwise ones. +To see this, let's try printing the Narwhals equivalent of the last expression above (the one that SQL rejects): + +```python exec="1" result="python" session="pushdown" source="above" +import narwhals as nw + +print(nw.col("a").sum().abs().over("b")) +``` + +Note how Narwhals automatically inserted the `over` operation _before_ the `abs` one. In other words, instead +of doing + +- `sum` -> `abs` -> `over` + +it did + +- `sum` -> `over` -> `abs` + +thus allowing the expression to be valid for SQL engines! + +This is what we refer to as "pushing down `over` nodes". The idea is: + +- Elementwise operations operate row-by-row and don't depend on the rows around them. +- An `over` node partitions or orders a computation. +- Therefore, an elementwise operation followed by an `over` operation is the same + as doing the `over` operation followed by that same elementwise operation! + +Note that the pushdown also applies to any arguments to the elementwise operation. +For example, if we have + +```python +(nw.col("a").sum() + nw.col("b").sum()).over("c") +``` + +then `+` is an elementwise operation and so can be swapped with `over`. We just need +to take care to apply the `over` operation to all the arguments of `+`, so that we +end up with + +```python +nw.col("a").sum().over("c") + nw.col("b").sum().over("c") +``` + +!!! info + In general, query optimisation is out-of-scope for Narwhals. We consider this + expression rewrite acceptable because: + - It's simple. + - It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends. diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index effc37cc4e..034cf28759 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -9,7 +9,6 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import concat_tables, native_to_narwhals_dtype, repeat from narwhals._compliant import EagerDataFrame -from narwhals._expression_parsing import ExprKind from narwhals._utils import ( Implementation, Version, @@ -330,7 +329,7 @@ def simple_select(self, *column_names: str) -> Self: ) def select(self, *exprs: ArrowExpr) -> Self: - new_series = self._evaluate_into_exprs(*exprs) + new_series = self._evaluate_exprs(*exprs) if not new_series: # return empty dataframe, like Polars does return self._with_native( @@ -357,7 +356,7 @@ def with_columns(self, *exprs: ArrowExpr) -> Self: # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame) # All `pyarrow` data is immutable, so this is fine native_frame = self.native - new_columns = self._evaluate_into_exprs(*exprs) + new_columns = self._evaluate_exprs(*exprs) columns = self.columns for col_value in new_columns: @@ -402,12 +401,10 @@ def join( ) return self._with_native( - self.with_columns( - plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) - ) + self.with_columns(plx.lit(0, None).alias(key_token).broadcast()) .native.join( other.with_columns( - plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) + plx.lit(0, None).alias(key_token).broadcast() ).native, keys=key_token, right_keys=key_token, @@ -517,8 +514,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: return self.select(row_index, plx.all()) def filter(self, predicate: ArrowExpr) -> Self: - # `[0]` is safe as the predicate's expression only returns a single column - mask_native = self._evaluate_into_exprs(predicate)[0].native + mask_native = self._evaluate_single_output_expr(predicate).native return self._with_native( self.native.filter(mask_native), validate_column_names=False ) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 63c048837f..0e59cf039b 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -21,8 +21,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs - from narwhals._expression_parsing import ExprMetadata + from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries from narwhals._utils import Version, _LimitedContext @@ -33,23 +32,15 @@ def __init__( self, call: EvalSeries[ArrowDataFrame, ArrowSeries], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[ArrowDataFrame], alias_output_names: AliasNames | None, version: Version, - scalar_kwargs: ScalarKwargs | None = None, - implementation: Implementation | None = None, + implementation: Implementation = Implementation.PYARROW, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name - self._depth = depth self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None @classmethod def from_column_names( @@ -58,7 +49,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: ArrowDataFrame) -> list[ArrowSeries]: try: @@ -75,8 +65,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, @@ -94,8 +82,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, @@ -113,7 +99,7 @@ def _reuse_series_extra_kwargs( def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: meta = self._metadata - if partition_by and meta is not None and not meta.is_scalar_like: + if partition_by and not meta.is_scalar_like: msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow." raise NotImplementedError(msg) @@ -167,8 +153,6 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 47bd15a37d..22785f0ebe 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -76,8 +76,9 @@ def _configure_agg( ) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]: option: AggregateOptions | None = None function_name = self._leaf_name(expr) + kwargs = self._kwargs(expr) if function_name in self._OPTION_VARIANCE: - ddof = expr._scalar_kwargs.get("ddof", 1) + ddof = kwargs["ddof"] option = pc.VarianceOptions(ddof=ddof) elif function_name in self._OPTION_COUNT_ALL: option = pc.CountOptions(mode="all") @@ -128,10 +129,11 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: output_names, aliases = evaluate_output_names_and_aliases( expr, self.compliant, exclude ) - - if expr._depth == 0: + md = expr._metadata + op_nodes_reversed = list(md.op_nodes_reversed()) + if len(op_nodes_reversed) == 1: # e.g. `agg(nw.len())` - if expr._function_name != "len": # pragma: no cover + if op_nodes_reversed[0].name != "len": # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 02a2417014..98282a575e 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -13,7 +13,7 @@ from narwhals._arrow.selectors import ArrowSelectorNamespace from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import cast_to_comparable_string_types -from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._compliant import EagerNamespace from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, @@ -23,13 +23,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from narwhals._arrow.typing import ( - ArrayOrScalar, - ChunkedArrayAny, - Incomplete, - ScalarAny, - ) - from narwhals._compliant.typing import ScalarKwargs + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny from narwhals._utils import Version from narwhals.typing import IntoDType, NonNestedLiteral @@ -65,8 +59,6 @@ def len(self) -> ArrowExpr: lambda df: [ ArrowSeries.from_iterable([len(df.native)], name="len", context=self) ], - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, version=self._version, @@ -83,8 +75,6 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: return self._expr( lambda df: [_lit_arrow_series(df)], - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -99,8 +89,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -115,8 +103,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -130,8 +116,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -148,8 +132,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -167,8 +149,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -186,8 +166,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -220,9 +198,6 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: def selectors(self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace.from_namespace(self) - def when(self, predicate: ArrowExpr) -> ArrowWhen: - return ArrowWhen.from_expr(predicate, context=self) - def concat_str( self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool ) -> ArrowExpr: @@ -247,8 +222,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -268,33 +241,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, ) - -class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]): - @property - def _then(self) -> type[ArrowThen]: - return ArrowThen - def _if_then_else( self, when: ChunkedArrayAny, then: ChunkedArrayAny, - otherwise: ArrayOrScalar | NonNestedLiteral, - /, + otherwise: ChunkedArrayAny | None = None, ) -> ChunkedArrayAny: otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise return pc.if_else(when, then, otherwise) - - -class ArrowThen( - CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr -): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 459e0022bb..62bb014083 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from narwhals._arrow.dataframe import ArrowDataFrame # noqa: F401 from narwhals._arrow.series import ArrowSeries # noqa: F401 - from narwhals._compliant.typing import ScalarKwargs class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]): @@ -18,15 +17,9 @@ def _selector(self) -> type[ArrowSelector]: class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> ArrowExpr: return ArrowExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index e4761cc681..9fec53f093 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -24,7 +24,6 @@ zeros, ) from narwhals._compliant import EagerSeries, EagerSeriesHist -from narwhals._expression_parsing import ExprKind from narwhals._typing_compat import assert_never from narwhals._utils import ( Implementation, @@ -68,12 +67,10 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, PythonLiteral, RankMethod, RollingInterpolationMethod, SizedMultiIndexSelector, - TemporalLiteral, _1DArray, _2DArray, _SliceIndex, @@ -491,9 +488,9 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: return self.native.to_numpy() def alias(self, name: str) -> Self: - result = self.__class__(self.native, name=name, version=self._version) - result._broadcast = self._broadcast - return result + ret = self.__class__(self.native, name=name, version=self._version) + ret._broadcast = self._broadcast + return ret @property def dtype(self) -> DType: @@ -846,26 +843,21 @@ def quantile( def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_native(self.native[offset::n]) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: - _, lower = ( - extract_native(self, lower_bound) if lower_bound is not None else (None, None) - ) - _, upper = ( - extract_native(self, upper_bound) if upper_bound is not None else (None, None) - ) - - if lower is None: - return self._with_native(pc.min_element_wise(self.native, upper)) - if upper is None: - return self._with_native(pc.max_element_wise(self.native, lower)) + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: + _, lower = extract_native(self, lower_bound) + _, upper = extract_native(self, upper_bound) return self._with_native( pc.max_element_wise(pc.min_element_wise(self.native, upper), lower) ) + def clip_lower(self, lower_bound: Self) -> Self: + _, lower = extract_native(self, lower_bound) + return self._with_native(pc.max_element_wise(self.native, lower)) + + def clip_upper(self, upper_bound: Self) -> Self: + _, upper = extract_native(self, upper_bound) + return self._with_native(pc.min_element_wise(self.native, upper)) + def to_arrow(self) -> ArrayAny: return self.native.combine_chunks() @@ -876,8 +868,7 @@ def mode(self, *, keep: ModeKeepStrategy) -> ArrowSeries: name=col_token, normalize=False, sort=False, parallel=False ) result = counts.filter( - plx.col(col_token) - == plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION) + plx.col(col_token) == plx.col(col_token).max().broadcast() ).get_column(self.name) return result.head(1) if keep == "any" else result diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 6768a1d92a..a571d4eabd 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -6,7 +6,12 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format +from narwhals._arrow.utils import ( + ArrowSeriesNamespace, + extract_native, + lit, + parse_datetime_format, +) from narwhals._compliant.any_namespace import StringNamespace if TYPE_CHECKING: @@ -18,25 +23,23 @@ class ArrowSeriesStringNamespace(ArrowSeriesNamespace, StringNamespace["ArrowSer def len_chars(self) -> ArrowSeries: return self.with_native(pc.utf8_length(self.native)) - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries: + def replace( + self, value: ArrowSeries, pattern: str, *, literal: bool, n: int + ) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex - try: - arr = fn(self.native, pattern, replacement=value, max_replacements=n) - except TypeError as e: - if not isinstance(value, str): - msg = "PyArrow backed `.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + _, value_native = extract_native(self.compliant, value) + if not isinstance(value_native, pa.StringScalar): + msg = "PyArrow backed `.str.replace` only supports str replacement values" + raise TypeError(msg) + arr = fn( + self.native, pattern, replacement=value_native.as_py(), max_replacements=n + ) return self.with_native(arr) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries: - try: - return self.replace(pattern, value, literal=literal, n=-1) - except TypeError as e: - if not isinstance(value, str): - msg = "PyArrow backed `.str.replace_all` only supports str replacement values." - raise TypeError(msg) from e - raise + def replace_all( + self, value: ArrowSeries, pattern: str, *, literal: bool + ) -> ArrowSeries: + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> ArrowSeries: return self.with_native( diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index 70bd22588b..600ca4e1fd 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -53,7 +53,6 @@ NativeFrameT_co, NativeSeriesT_co, ) -from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen from narwhals._compliant.window import WindowInputs __all__ = [ @@ -70,8 +69,6 @@ "CompliantSeries", "CompliantSeriesOrNativeExprT_co", "CompliantSeriesT", - "CompliantThen", - "CompliantWhen", "DepthTrackingExpr", "DepthTrackingGroupBy", "DepthTrackingNamespace", @@ -90,7 +87,6 @@ "EagerSeriesStringNamespace", "EagerSeriesStructNamespace", "EagerSeriesT", - "EagerWhen", "EvalNames", "EvalSeries", "LazyExpr", diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index 7538c16155..b7e48a273f 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol, TypeVar from narwhals._utils import CompliantT_co, _StoresCompliant @@ -12,6 +12,8 @@ from narwhals._compliant.typing import Accessor from narwhals.typing import NonNestedLiteral, TimeUnit +T = TypeVar("T") + __all__ = [ "CatNamespace", "DateTimeNamespace", @@ -81,28 +83,24 @@ def to_lowercase(self) -> CompliantT_co: ... def to_uppercase(self) -> CompliantT_co: ... -class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): +class StringNamespace(_StoresCompliant[T], Protocol[T]): _accessor: ClassVar[Accessor] = "str" - def len_chars(self) -> CompliantT_co: ... - def replace( - self, pattern: str, value: str, *, literal: bool, n: int - ) -> CompliantT_co: ... - def replace_all( - self, pattern: str, value: str, *, literal: bool - ) -> CompliantT_co: ... - def strip_chars(self, characters: str | None) -> CompliantT_co: ... - def starts_with(self, prefix: str) -> CompliantT_co: ... - def ends_with(self, suffix: str) -> CompliantT_co: ... - def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ... - def slice(self, offset: int, length: int | None) -> CompliantT_co: ... - def split(self, by: str) -> CompliantT_co: ... - def to_datetime(self, format: str | None) -> CompliantT_co: ... - def to_date(self, format: str | None) -> CompliantT_co: ... - def to_lowercase(self) -> CompliantT_co: ... - def to_titlecase(self) -> CompliantT_co: ... - def to_uppercase(self) -> CompliantT_co: ... - def zfill(self, width: int) -> CompliantT_co: ... + def len_chars(self) -> T: ... + def replace(self, value: T, pattern: str, *, literal: bool, n: int) -> T: ... + def replace_all(self, value: T, pattern: str, *, literal: bool) -> T: ... + def strip_chars(self, characters: str | None) -> T: ... + def starts_with(self, prefix: str) -> T: ... + def ends_with(self, suffix: str) -> T: ... + def contains(self, pattern: str, *, literal: bool) -> T: ... + def slice(self, offset: int, length: int | None) -> T: ... + def split(self, by: str) -> T: ... + def to_datetime(self, format: str | None) -> T: ... + def to_date(self, format: str | None) -> T: ... + def to_lowercase(self) -> T: ... + def to_titlecase(self) -> T: ... + def to_uppercase(self) -> T: ... + def zfill(self, width: int) -> T: ... class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 20f7f08f5c..5561a48a4d 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -21,10 +21,7 @@ FillNullStrategy, IntoDType, ModeKeepStrategy, - NonNestedLiteral, - NumericLiteral, RankMethod, - TemporalLiteral, ) __all__ = ["CompliantColumn"] @@ -35,38 +32,36 @@ class CompliantColumn(Protocol): _version: Version - def __add__(self, other: Any) -> Self: ... - def __and__(self, other: Any) -> Self: ... - def __eq__(self, other: object) -> Self: ... # type: ignore[override] - def __floordiv__(self, other: Any) -> Self: ... - def __ge__(self, other: Any) -> Self: ... - def __gt__(self, other: Any) -> Self: ... + def __add__(self, other: Self) -> Self: ... + def __and__(self, other: Self) -> Self: ... + def __eq__(self, other: Self) -> Self: ... # type: ignore[override] + def __floordiv__(self, other: Self) -> Self: ... + def __ge__(self, other: Self) -> Self: ... + def __gt__(self, other: Self) -> Self: ... def __invert__(self) -> Self: ... - def __le__(self, other: Any) -> Self: ... - def __lt__(self, other: Any) -> Self: ... - def __mod__(self, other: Any) -> Self: ... - def __mul__(self, other: Any) -> Self: ... - def __ne__(self, other: object) -> Self: ... # type: ignore[override] - def __or__(self, other: Any) -> Self: ... - def __pow__(self, other: Any) -> Self: ... - def __rfloordiv__(self, other: Any) -> Self: ... - def __rmod__(self, other: Any) -> Self: ... - def __rpow__(self, other: Any) -> Self: ... - def __rsub__(self, other: Any) -> Self: ... - def __rtruediv__(self, other: Any) -> Self: ... - def __sub__(self, other: Any) -> Self: ... - def __truediv__(self, other: Any) -> Self: ... + def __le__(self, other: Self) -> Self: ... + def __lt__(self, other: Self) -> Self: ... + def __mod__(self, other: Self) -> Self: ... + def __mul__(self, other: Self) -> Self: ... + def __ne__(self, other: Self) -> Self: ... # type: ignore[override] + def __or__(self, other: Self) -> Self: ... + def __pow__(self, other: Self) -> Self: ... + def __rfloordiv__(self, other: Self) -> Self: ... + def __rmod__(self, other: Self) -> Self: ... + def __rpow__(self, other: Self) -> Self: ... + def __rsub__(self, other: Self) -> Self: ... + def __rtruediv__(self, other: Self) -> Self: ... + def __sub__(self, other: Self) -> Self: ... + def __truediv__(self, other: Self) -> Self: ... def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def abs(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: IntoDType) -> Self: ... - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: ... + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: ... + def clip_lower(self, lower_bound: Self) -> Self: ... + def clip_upper(self, upper_bound: Self) -> Self: ... def cum_count(self, *, reverse: bool) -> Self: ... def cum_max(self, *, reverse: bool) -> Self: ... def cum_min(self, *, reverse: bool) -> Self: ... @@ -89,10 +84,7 @@ def exp(self) -> Self: ... def sqrt(self) -> Self: ... def fill_nan(self, value: float | None) -> Self: ... def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... def is_between( self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval @@ -105,66 +97,6 @@ def is_between( return (self > lower_bound) & (self < upper_bound) return (self >= lower_bound) & (self <= upper_bound) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - from decimal import Decimal - - other_abs: Self | NumericLiteral - other_is_nan: Self | bool - other_is_inf: Self | bool - other_is_not_inf: Self | bool - - if isinstance(other, (float, int, Decimal)): - from math import isinf, isnan - - # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447 - other_abs = other.__abs__() - other_is_nan = isnan(other) - other_is_inf = isinf(other) - - # Define the other_is_not_inf variable to prevent triggering the following warning: - # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be - # > removed in Python 3.16. - other_is_not_inf = not other_is_inf - - else: - other_abs, other_is_nan = other.abs(), other.is_nan() - other_is_not_inf = other.is_finite() | other_is_nan - other_is_inf = ~other_is_not_inf - - rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol - tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None) - - self_is_nan = self.is_nan() - self_is_not_inf = self.is_finite() | self_is_nan - - # Values are close if abs_diff <= tolerance, and both finite - is_close = ( - ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf - ) - - # Handle infinity cases: infinities are close/equal if they have the same sign - self_sign, other_sign = self > 0, other > 0 - is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign) - - # Handle nan cases: - # * If any value is NaN, then False (via `& ~either_nan`) - # * However, if `nans_equals = True` and if _both_ values are NaN, then True - either_nan = self_is_nan | other_is_nan - result = (is_close | is_same_inf) & ~either_nan - - if nans_equal: - both_nan = self_is_nan & other_is_nan - result = result | both_nan - - return result - def is_duplicated(self) -> Self: return ~self.is_unique() diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 7f4cc017e2..03361545bf 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -37,6 +37,7 @@ is_slice_index, is_slice_none, ) +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from io import BytesIO @@ -153,7 +154,6 @@ def simple_select(self, *column_names: str) -> Self: def sort( self, *by: str, descending: bool | Sequence[bool], nulls_last: bool ) -> Self: ... - def tail(self, n: int) -> Self: ... def unique( self, subset: Sequence[str] | None, @@ -347,21 +347,24 @@ def _with_native( def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: return check_columns_exist(subset, available=self.columns) - def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: + def _evaluate_single_output_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" - result: Sequence[EagerSeriesT] = expr(self) - assert len(result) == 1 # debug assertion # noqa: S101 + # NOTE: Ignore intermittent [False Negative] + # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr" + # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame" + result = self._evaluate_expr(expr) # pyright: ignore[reportArgumentType] + if len(result) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] - def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: + def _evaluate_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: # NOTE: Ignore intermittent [False Negative] # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr" # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame" - return tuple( - chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs) # pyright: ignore[reportArgumentType] - ) + return tuple(chain.from_iterable(self._evaluate_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] - def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: + def _evaluate_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """Return list of raw columns. For eager backends we alias operations at each step. diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 05e371f988..997746d429 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -35,6 +35,7 @@ zip_strict, ) from narwhals.dependencies import is_numpy_array, is_numpy_scalar +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -43,8 +44,8 @@ from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace from narwhals._compliant.series import CompliantSeries - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries + from narwhals._expression_parsing import ExprMetadata from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.typing import ( ClosedInterval, @@ -52,10 +53,8 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RankMethod, RollingInterpolationMethod, - TemporalLiteral, TimeUnit, ) @@ -90,7 +89,22 @@ class CompliantExpr( _implementation: Implementation _evaluate_output_names: EvalNames[CompliantFrameT] _alias_output_names: AliasNames | None - _metadata: ExprMetadata | None + # This should be set with extreme care, only in `_expression_parsing.py`, + # and never from within any compliant class. + _opt_metadata: ExprMetadata | None = None + + @property + def _metadata(self) -> ExprMetadata: + if self._opt_metadata is None: # pragma: no cover + msg = ( + "`_opt_metadata` is None. This is usually the result of trying to do " + "some operation (such as `over`) which requires access to the metadata " + "at the compliant level. You may want to consider rewriting your logic " + "so that this operation is not necessary. Ideally you should avoid " + "setting `_opt_metadata` manually." + ) + raise AssertionError(msg) + return self._opt_metadata def __call__( self, df: CompliantFrameT @@ -111,11 +125,10 @@ def from_column_names( *, context: _LimitedContext, ) -> Self: ... - def broadcast( - self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] - ) -> Self: ... + def broadcast(self) -> Self: ... # NOTE: `polars` + def alias(self, name: str) -> Self: ... def all(self) -> Self: ... def any(self) -> Self: ... def count(self) -> Self: ... @@ -172,9 +185,6 @@ class DepthTrackingExpr( ImplExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co], Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co], ): - _depth: int - _function_name: str - # NOTE: pyright bug? # Method "from_column_names" overrides class "CompliantExpr" in an incompatible manner # Parameter 2 type mismatch: base parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]", override parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]" @@ -188,7 +198,6 @@ def from_column_names( # pyright: ignore[reportIncompatibleMethodOverride] /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: ... def _is_elementary(self) -> bool: @@ -206,10 +215,7 @@ def _is_elementary(self) -> bool: Elementary expressions are the only ones supported properly in pandas, PyArrow, and Dask. """ - return self._depth < 2 - - def __repr__(self) -> str: # pragma: no cover - return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})" + return len(list(self._metadata.op_nodes_reversed())) <= 2 class EagerExpr( @@ -217,19 +223,15 @@ class EagerExpr( Protocol[EagerDataFrameT, EagerSeriesT], ): _call: EvalSeries[EagerDataFrameT, EagerSeriesT] - _scalar_kwargs: ScalarKwargs def __init__( self, call: EvalSeries[EagerDataFrameT, EagerSeriesT], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[EagerDataFrameT], alias_output_names: AliasNames | None, implementation: Implementation, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: ... def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]: @@ -243,30 +245,22 @@ def _from_callable( cls, func: EvalSeries[EagerDataFrameT, EagerSeriesT], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[EagerDataFrameT], alias_output_names: AliasNames | None, context: _LimitedContext, - scalar_kwargs: ScalarKwargs | None = None, ) -> Self: return cls( func, - depth=depth, - function_name=function_name, evaluate_output_names=evaluate_output_names, alias_output_names=alias_output_names, implementation=context._implementation, version=context._version, - scalar_kwargs=scalar_kwargs, ) @classmethod def _from_series(cls, series: EagerSeriesT) -> Self: return cls( lambda _df: [series], - depth=0, - function_name="series", evaluate_output_names=lambda _df: [series.name], alias_output_names=None, implementation=series._implementation, @@ -302,22 +296,14 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: return self.__class__( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def _reuse_series( - self, - method_name: str, - *, - returns_scalar: bool = False, - scalar_kwargs: ScalarKwargs | None = None, - **expressifiable_args: Any, + self, method_name: str, *, returns_scalar: bool = False, **kwargs: Any ) -> Self: """Reuse Series implementation for expression. @@ -328,25 +314,18 @@ def _reuse_series( method_name: name of method. returns_scalar: whether the Series version returns a scalar. In this case, the expression version should return a 1-row Series. - scalar_kwargs: non-expressifiable args which we may need to reuse in `agg` or `over`, - such as `ddof` for `std` and `var`. - expressifiable_args: keyword arguments to pass to function, which may - be expressifiable (e.g. `nw.col('a').is_between(3, nw.col('b')))`). + kwargs: keyword arguments to pass to function. """ func = partial( self._reuse_series_inner, method_name=method_name, returns_scalar=returns_scalar, - scalar_kwargs=scalar_kwargs or {}, - expressifiable_args=expressifiable_args, + **kwargs, ) return self._from_callable( func, - depth=self._depth + 1, - function_name=f"{self._function_name}->{method_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - scalar_kwargs=scalar_kwargs, context=self, ) @@ -367,15 +346,15 @@ def _reuse_series_inner( *, method_name: str, returns_scalar: bool, - scalar_kwargs: ScalarKwargs, - expressifiable_args: dict[str, Any], + **kwargs: Any, ) -> Sequence[EagerSeriesT]: kwargs = { - **scalar_kwargs, **{ - name: df._evaluate_expr(value) if self._is_expr(value) else value - for name, value in expressifiable_args.items() - }, + name: df._evaluate_single_output_expr(value) + if self._is_expr(value) + else value + for name, value in kwargs.items() + } } method = methodcaller( method_name, @@ -417,7 +396,9 @@ def _reuse_series_namespace( def inner(df: EagerDataFrameT) -> list[EagerSeriesT]: kwargs = { - name: df._evaluate_expr(value) if self._is_expr(value) else value + name: df._evaluate_single_output_expr(value) + if self._is_expr(value) + else value for name, value in expressifiable_args.items() } return [ @@ -427,15 +408,12 @@ def inner(df: EagerDataFrameT) -> list[EagerSeriesT]: return self._from_callable( inner, - depth=self._depth + 1, - function_name=f"{self._function_name}->{series_namespace}.{method_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - scalar_kwargs=self._scalar_kwargs, context=self, ) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Mark the resulting Series with `_broadcast = True`. # Then, when extracting native objects, `extract_native` will # know what to do. @@ -448,82 +426,79 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: return type(self)( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def cast(self, dtype: IntoDType) -> Self: return self._reuse_series("cast", dtype=dtype) - def _with_binary(self, operator: str, other: Self | Any, /) -> Self: + def _with_binary(self, operator: str, other: Self, /) -> Self: return self._reuse_series(operator, other=other) - def _with_binary_right(self, operator: str, other: Self | Any, /) -> Self: + def _with_binary_right(self, operator: str, other: Self, /) -> Self: return self.alias("literal")._reuse_series(operator, other=other) - def __eq__(self, other: Self | Any) -> Self: # type: ignore[override] + def __eq__(self, other: Self) -> Self: # type: ignore[override] return self._with_binary("__eq__", other) - def __ne__(self, other: Self | Any) -> Self: # type: ignore[override] + def __ne__(self, other: Self) -> Self: # type: ignore[override] return self._with_binary("__ne__", other) - def __ge__(self, other: Self | Any) -> Self: + def __ge__(self, other: Self) -> Self: return self._with_binary("__ge__", other) - def __gt__(self, other: Self | Any) -> Self: + def __gt__(self, other: Self) -> Self: return self._with_binary("__gt__", other) - def __le__(self, other: Self | Any) -> Self: + def __le__(self, other: Self) -> Self: return self._with_binary("__le__", other) - def __lt__(self, other: Self | Any) -> Self: + def __lt__(self, other: Self) -> Self: return self._with_binary("__lt__", other) - def __and__(self, other: Self | bool | Any) -> Self: + def __and__(self, other: Self) -> Self: return self._with_binary("__and__", other) - def __or__(self, other: Self | bool | Any) -> Self: + def __or__(self, other: Self) -> Self: return self._with_binary("__or__", other) - def __add__(self, other: Self | Any) -> Self: + def __add__(self, other: Self) -> Self: return self._with_binary("__add__", other) - def __sub__(self, other: Self | Any) -> Self: + def __sub__(self, other: Self) -> Self: return self._with_binary("__sub__", other) - def __rsub__(self, other: Self | Any) -> Self: + def __rsub__(self, other: Self) -> Self: return self._with_binary_right("__rsub__", other) - def __mul__(self, other: Self | Any) -> Self: + def __mul__(self, other: Self) -> Self: return self._with_binary("__mul__", other) - def __truediv__(self, other: Self | Any) -> Self: + def __truediv__(self, other: Self) -> Self: return self._with_binary("__truediv__", other) - def __rtruediv__(self, other: Self | Any) -> Self: + def __rtruediv__(self, other: Self) -> Self: return self._with_binary_right("__rtruediv__", other) - def __floordiv__(self, other: Self | Any) -> Self: + def __floordiv__(self, other: Self) -> Self: return self._with_binary("__floordiv__", other) - def __rfloordiv__(self, other: Self | Any) -> Self: + def __rfloordiv__(self, other: Self) -> Self: return self._with_binary_right("__rfloordiv__", other) - def __pow__(self, other: Self | Any) -> Self: + def __pow__(self, other: Self) -> Self: return self._with_binary("__pow__", other) - def __rpow__(self, other: Self | Any) -> Self: + def __rpow__(self, other: Self) -> Self: return self._with_binary_right("__rpow__", other) - def __mod__(self, other: Self | Any) -> Self: + def __mod__(self, other: Self) -> Self: return self._with_binary("__mod__", other) - def __rmod__(self, other: Self | Any) -> Self: + def __rmod__(self, other: Self) -> Self: return self._with_binary_right("__rmod__", other) # Unary @@ -550,14 +525,10 @@ def median(self) -> Self: return self._reuse_series("median", returns_scalar=True) def std(self, *, ddof: int) -> Self: - return self._reuse_series( - "std", returns_scalar=True, scalar_kwargs={"ddof": ddof} - ) + return self._reuse_series("std", returns_scalar=True, ddof=ddof) def var(self, *, ddof: int) -> Self: - return self._reuse_series( - "var", returns_scalar=True, scalar_kwargs={"ddof": ddof} - ) + return self._reuse_series("var", returns_scalar=True, ddof=ddof) def skew(self) -> Self: return self._reuse_series("skew", returns_scalar=True) @@ -585,15 +556,17 @@ def arg_max(self) -> Self: # Other - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: return self._reuse_series( "clip", lower_bound=lower_bound, upper_bound=upper_bound ) + def clip_lower(self, lower_bound: Self) -> Self: + return self._reuse_series("clip_lower", lower_bound=lower_bound) + + def clip_upper(self, upper_bound: Self) -> Self: + return self._reuse_series("clip_upper", upper_bound=upper_bound) + def is_null(self) -> Self: return self._reuse_series("is_null") @@ -604,13 +577,10 @@ def fill_nan(self, value: float | None) -> Self: return self._reuse_series("fill_nan", value=value) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: return self._reuse_series( - "fill_null", value=value, scalar_kwargs={"strategy": strategy, "limit": limit} + "fill_null", value=value, strategy=strategy, limit=limit ) def is_in(self, other: Any) -> Self: @@ -666,20 +636,15 @@ def alias(self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: if len(names) != 1: msg = f"Expected function with single output, found output names: {names}" - raise ValueError(msg) + raise MultiOutputExpressionError(msg) return [name] - # Define this one manually, so that we can - # override `output_names` and not increase depth return type(self)( lambda df: [series.alias(name) for series in self(df)], - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def is_unique(self) -> Self: @@ -697,14 +662,15 @@ def quantile( return self._reuse_series( "quantile", returns_scalar=True, - scalar_kwargs={"quantile": quantile, "interpolation": interpolation}, + quantile=quantile, + interpolation=interpolation, ) def head(self, n: int) -> Self: - return self._reuse_series("head", scalar_kwargs={"n": n}) + return self._reuse_series("head", n=n) def tail(self, n: int) -> Self: - return self._reuse_series("tail", scalar_kwargs={"n": n}) + return self._reuse_series("tail", n=n) def round(self, decimals: int) -> Self: return self._reuse_series("round", decimals=decimals) @@ -722,7 +688,7 @@ def gather_every(self, n: int, offset: int) -> Self: return self._reuse_series("gather_every", n=n, offset=offset) def mode(self, *, keep: ModeKeepStrategy) -> Self: - return self._reuse_series("mode", scalar_kwargs={"keep": keep}) + return self._reuse_series("mode", keep=keep) def is_finite(self) -> Self: return self._reuse_series("is_finite") @@ -730,11 +696,9 @@ def is_finite(self) -> Self: def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._reuse_series( "rolling_mean", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - }, + window_size=window_size, + min_samples=min_samples, + center=center, ) def rolling_std( @@ -742,22 +706,15 @@ def rolling_std( ) -> Self: return self._reuse_series( "rolling_std", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - "ddof": ddof, - }, + window_size=window_size, + min_samples=min_samples, + center=center, + ddof=ddof, ) def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._reuse_series( - "rolling_sum", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - }, + "rolling_sum", window_size=window_size, min_samples=min_samples, center=center ) def rolling_var( @@ -765,12 +722,10 @@ def rolling_var( ) -> Self: return self._reuse_series( "rolling_var", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - "ddof": ddof, - }, + window_size=window_size, + min_samples=min_samples, + center=center, + ddof=ddof, ) def map_batches( @@ -814,35 +769,31 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]: return self._from_callable( func, - depth=self._depth + 1, - function_name=self._function_name + "->map_batches", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, context=self, ) def shift(self, n: int) -> Self: - return self._reuse_series("shift", scalar_kwargs={"n": n}) + return self._reuse_series("shift", n=n) def cum_sum(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_sum", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_sum", reverse=reverse) def cum_count(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_count", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_count", reverse=reverse) def cum_min(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_min", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_min", reverse=reverse) def cum_max(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_max", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_max", reverse=reverse) def cum_prod(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_prod", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_prod", reverse=reverse) def rank(self, method: RankMethod, *, descending: bool) -> Self: - return self._reuse_series( - "rank", scalar_kwargs={"method": method, "descending": descending} - ) + return self._reuse_series("rank", method=method, descending=descending) def log(self, base: float) -> Self: return self._reuse_series("log", base=base) @@ -854,28 +805,12 @@ def sqrt(self) -> Self: return self._reuse_series("sqrt") def is_between( - self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval + self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval ) -> Self: return self._reuse_series( "is_between", lower_bound=lower_bound, upper_bound=upper_bound, closed=closed ) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - return self._reuse_series( - "is_close", - other=other, - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, - ) - def first(self) -> Self: return self._reuse_series("first", returns_scalar=True) @@ -1114,12 +1049,16 @@ class EagerExprStringNamespace( def len_chars(self) -> EagerExprT: return self.compliant._reuse_series_namespace("str", "len_chars") - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> EagerExprT: + def replace( + self, value: EagerExprT, pattern: str, *, literal: bool, n: int + ) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace", pattern=pattern, value=value, literal=literal, n=n ) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> EagerExprT: + def replace_all( + self, value: EagerExprT, pattern: str, *, literal: bool + ) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace_all", pattern=pattern, value=value, literal=literal ) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index f9529cd442..cb62edf7ac 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from itertools import chain from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar @@ -31,9 +30,6 @@ ) -_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") - - def _evaluate_aliases( frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], / ) -> list[str]: @@ -170,7 +166,12 @@ def _remap_expr_name( @classmethod def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: """Return the last function name in the chain defined by `expr`.""" - return _RE_LEAF_NAME.sub("", expr._function_name) + return next(expr._metadata.op_nodes_reversed()).name + + @classmethod + def _kwargs(cls, expr: DepthTrackingExprAny, /) -> dict[str, Any]: + """Return the last function kwargs in the chain defined by `expr`.""" + return next(expr._metadata.op_nodes_reversed()).kwargs class EagerGroupBy( diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 4cc7130828..6f7062be35 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -10,36 +10,31 @@ DepthTrackingExprT, EagerDataFrameT, EagerExprT, - EagerSeriesT, + EagerSeriesT_co, LazyExprT, NativeFrameT, NativeSeriesT, ) -from narwhals._expression_parsing import is_expr, is_series from narwhals._utils import ( exclude_column_names, get_column_names, passthrough_column_names, ) -from narwhals.dependencies import is_numpy_array, is_numpy_array_2d +from narwhals.dependencies import is_numpy_array_2d if TYPE_CHECKING: - from collections.abc import Container, Iterable, Sequence + from collections.abc import Iterable, Sequence from typing_extensions import TypeAlias, TypeIs from narwhals._compliant.selectors import CompliantSelectorNamespace - from narwhals._compliant.when_then import CompliantWhen, EagerWhen from narwhals._utils import Implementation, Version - from narwhals.expr import Expr - from narwhals.series import Series from narwhals.typing import ( ConcatMethod, Into1DArray, IntoDType, IntoSchema, NonNestedLiteral, - _1DArray, _2DArray, ) @@ -60,33 +55,20 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def parse_into_expr( - self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool - ) -> CompliantExprT | NonNestedLiteral: - if is_expr(data): - expr = data._to_compliant_expr(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - if isinstance(data, str) and not str_as_lit: - return self.col(data) - return data - # NOTE: `polars` def all(self) -> CompliantExprT: return self._expr.from_column_names(get_column_names, context=self) - def col(self, *column_names: str) -> CompliantExprT: - return self._expr.from_column_names( - passthrough_column_names(column_names), context=self - ) + def col(self, *names: str) -> CompliantExprT: + return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + def exclude(self, *names: str) -> CompliantExprT: return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), context=self + partial(exclude_column_names, names=names), context=self ) - def nth(self, *column_indices: int) -> CompliantExprT: - return self._expr.from_column_indices(*column_indices, context=self) + def nth(self, indices: Sequence[int]) -> CompliantExprT: + return self._expr.from_column_indices(*indices, context=self) def len(self) -> CompliantExprT: ... def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ... @@ -103,9 +85,6 @@ def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def concat( self, items: Iterable[CompliantFrameT], *, how: ConcatMethod ) -> CompliantFrameT: ... - def when( - self, predicate: CompliantExprT - ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ... def concat_str( self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool ) -> CompliantExprT: ... @@ -124,20 +103,14 @@ class DepthTrackingNamespace( Protocol[CompliantFrameT, DepthTrackingExprT], ): def all(self) -> DepthTrackingExprT: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) + return self._expr.from_column_names(get_column_names, context=self) - def col(self, *column_names: str) -> DepthTrackingExprT: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) + def col(self, *names: str) -> DepthTrackingExprT: + return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT: + def exclude(self, *names: str) -> DepthTrackingExprT: return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, + partial(exclude_column_names, names=names), context=self ) @@ -163,7 +136,7 @@ def from_native(self, data: NativeFrameT | Any, /) -> CompliantLazyFrameT: class EagerNamespace( DepthTrackingNamespace[EagerDataFrameT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], + Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT, NativeFrameT, NativeSeriesT], ): @property def _backend_version(self) -> tuple[int, ...]: @@ -172,10 +145,44 @@ def _backend_version(self) -> tuple[int, ...]: @property def _dataframe(self) -> type[EagerDataFrameT]: ... @property - def _series(self) -> type[EagerSeriesT]: ... - def when( - self, predicate: EagerExprT - ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ... + def _series(self) -> type[EagerSeriesT_co]: ... + def _if_then_else( + self, + when: NativeSeriesT, + then: NativeSeriesT, + otherwise: NativeSeriesT | None = None, + ) -> NativeSeriesT: ... + def when_then( + self, predicate: EagerExprT, then: EagerExprT, otherwise: EagerExprT | None = None + ) -> EagerExprT: + def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]: + predicate_s = df._evaluate_single_output_expr(predicate) + align = predicate_s._align_full_broadcast + + then_s = df._evaluate_single_output_expr(then) + if otherwise is None: + predicate_s, then_s = align(predicate_s, then_s) + result = self._if_then_else(predicate_s.native, then_s.native) + + if otherwise is None: + predicate_s, then_s = align(predicate_s, then_s) + result = self._if_then_else(predicate_s.native, then_s.native) + else: + otherwise_s = df._evaluate_single_output_expr(otherwise) + predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s) + result = self._if_then_else( + predicate_s.native, then_s.native, otherwise_s.native + ) + return [then_s._with_native(result)] + + return self._expr._from_callable( + func=func, + evaluate_output_names=getattr( + then, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(then, "_alias_output_names", None), + context=predicate, + ) def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT | NativeSeriesT]: return self._dataframe._is_native(obj) or self._series._is_native(obj) @@ -183,10 +190,10 @@ def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT | NativeSeriesT]: @overload def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ... @overload - def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ... + def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT_co: ... def from_native( self, data: NativeFrameT | NativeSeriesT | Any, / - ) -> EagerDataFrameT | EagerSeriesT: + ) -> EagerDataFrameT | EagerSeriesT_co: if self._dataframe._is_native(data): return self._dataframe.from_native(data, context=self) if self._series._is_native(data): @@ -194,23 +201,8 @@ def from_native( msg = f"Unsupported type: {type(data).__name__!r}" raise TypeError(msg) - def parse_into_expr( - self, - data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral, - /, - *, - str_as_lit: bool, - ) -> EagerExprT | NonNestedLiteral: - if not (is_series(data) or is_numpy_array(data)): - return super().parse_into_expr(data, str_as_lit=str_as_lit) - return self._expr._from_series( - data._compliant_series - if is_series(data) - else self._series.from_numpy(data, context=self) - ) - @overload - def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ... + def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT_co: ... @overload def from_numpy( @@ -222,7 +214,7 @@ def from_numpy( data: Into1DArray | _2DArray, /, schema: IntoSchema | Sequence[str] | None = None, - ) -> EagerDataFrameT | EagerSeriesT: + ) -> EagerDataFrameT | EagerSeriesT_co: if is_numpy_array_2d(data): return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 515d0ccd15..bb4f612d10 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -229,7 +229,7 @@ def _is_selector( ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]: return isinstance(other, type(self)) - @overload + @overload # type: ignore[override] def __sub__(self, other: Self) -> Self: ... @overload def __sub__( @@ -255,7 +255,7 @@ def names(df: FrameT) -> Sequence[str]: return self.selectors._selector.from_callables(series, names, context=self) return self._to_expr() - other - @overload + @overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... @overload def __or__( @@ -284,7 +284,7 @@ def names(df: FrameT) -> Sequence[str]: return self.selectors._selector.from_callables(series, names, context=self) return self._to_expr() | other - @overload + @overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... @overload def __and__( diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index bca1d80bbc..434b2dc760 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -13,6 +13,7 @@ from narwhals._compliant.typing import ( CompliantSeriesT_co, EagerDataFrameAny, + EagerSeriesT, EagerSeriesT_co, NativeSeriesT, NativeSeriesT_co, @@ -324,17 +325,17 @@ class EagerSeriesDateTimeNamespace( # type: ignore[misc] ): ... -class EagerSeriesListNamespace( # type: ignore[misc] - _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co], - ListNamespace[EagerSeriesT_co], - Protocol[EagerSeriesT_co, NativeSeriesT_co], +class EagerSeriesListNamespace( # pyright: ignore[reportInvalidTypeVarUse] + _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], + ListNamespace[EagerSeriesT], + Protocol[EagerSeriesT, NativeSeriesT_co], ): ... -class EagerSeriesStringNamespace( # type: ignore[misc] - _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co], - StringNamespace[EagerSeriesT_co], - Protocol[EagerSeriesT_co, NativeSeriesT_co], +class EagerSeriesStringNamespace( + _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], + StringNamespace[EagerSeriesT], + Protocol[EagerSeriesT, NativeSeriesT_co], ): ... diff --git a/narwhals/_compliant/when_then.py b/narwhals/_compliant/when_then.py deleted file mode 100644 index bc4db69382..0000000000 --- a/narwhals/_compliant/when_then.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast - -from narwhals._compliant.expr import CompliantExpr -from narwhals._compliant.typing import ( - CompliantExprAny, - CompliantFrameAny, - CompliantSeriesOrNativeExprAny, - EagerDataFrameT, - EagerExprT, - EagerSeriesT, - LazyExprAny, - NativeSeriesT, -) - -if TYPE_CHECKING: - from collections.abc import Sequence - - from typing_extensions import Self, TypeAlias - - from narwhals._compliant.typing import EvalSeries - from narwhals._utils import Implementation, Version, _LimitedContext - from narwhals.typing import NonNestedLiteral - - -__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"] - -ExprT = TypeVar("ExprT", bound=CompliantExprAny) -LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) -SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny) -FrameT = TypeVar("FrameT", bound=CompliantFrameAny) - -Scalar: TypeAlias = Any -"""A native literal value.""" - -IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar" -"""Anything that is convertible into a `CompliantExpr`.""" - - -class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]): - _condition: ExprT - _then_value: IntoExpr[SeriesT, ExprT] - _otherwise_value: IntoExpr[SeriesT, ExprT] | None - _implementation: Implementation - _version: Version - - @property - def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ... - def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ... - def then( - self, value: IntoExpr[SeriesT, ExprT], / - ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]: - return self._then.from_when(self, value) - - @classmethod - def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self: - obj = cls.__new__(cls) - obj._condition = condition - obj._then_value = None - obj._otherwise_value = None - obj._implementation = context._implementation - obj._version = context._version - return obj - - -WhenT_contra = TypeVar( - "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True -) - - -class CompliantThen( - CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra] -): - _call: EvalSeries[FrameT, SeriesT] - _when_value: CompliantWhen[FrameT, SeriesT, ExprT] - _implementation: Implementation - _version: Version - - @classmethod - def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self: - when._then_value = then - obj = cls.__new__(cls) - obj._call = when - obj._when_value = when - obj._evaluate_output_names = getattr( - then, "_evaluate_output_names", lambda _df: ["literal"] - ) - obj._alias_output_names = getattr(then, "_alias_output_names", None) - obj._implementation = when._implementation - obj._version = when._version - return obj - - def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT: - self._when_value._otherwise_value = otherwise - return cast("ExprT", self) - - -class EagerWhen( - CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT], -): - def _if_then_else( - self, - when: NativeSeriesT, - then: NativeSeriesT, - otherwise: NativeSeriesT | NonNestedLiteral | Scalar, - /, - ) -> NativeSeriesT: ... - - def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]: - is_expr = self._condition._is_expr - when: EagerSeriesT = self._condition(df)[0] - then: EagerSeriesT - align = when._align_full_broadcast - - if is_expr(self._then_value): - then = self._then_value(df)[0] - else: - then = when.alias("literal")._from_scalar(self._then_value) - then._broadcast = True - - if is_expr(self._otherwise_value): - otherwise = self._otherwise_value(df)[0] - when, then, otherwise = align(when, then, otherwise) - result = self._if_then_else(when.native, then.native, otherwise.native) - else: - when, then = align(when, then) - result = self._if_then_else(when.native, then.native, self._otherwise_value) - return [then._with_native(result)] diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index c1b8645404..2f2a9e374b 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -5,7 +5,6 @@ import dask.dataframe as dd from narwhals._dask.utils import add_row_index, evaluate_exprs -from narwhals._expression_parsing import ExprKind from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name from narwhals._typing_compat import assert_never from narwhals._utils import ( @@ -19,6 +18,7 @@ parse_columns_to_drop, zip_strict, ) +from narwhals.exceptions import MultiOutputExpressionError from narwhals.typing import CompliantLazyFrame if TYPE_CHECKING: @@ -108,6 +108,13 @@ def _iter_columns(self) -> Iterator[dx.Series]: for _col, ser in self.native.items(): # noqa: PERF102 yield ser + def _evaluate_single_output_expr(self, obj: DaskExpr) -> dx.Series: + results = obj._call(self) + if len(results) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) + return results[0] + def with_columns(self, *exprs: DaskExpr) -> Self: new_series = evaluate_exprs(self, *exprs) return self._with_native(self.native.assign(**dict(new_series))) @@ -223,10 +230,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: return self._with_native(add_row_index(self.native, name)) plx = self.__narwhals_namespace__() columns = self.columns - const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL) + const_expr = plx.lit(1, dtype=None).alias(name).broadcast() row_index_expr = ( plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by) - - 1 + - plx.lit(1, dtype=None).broadcast() ) return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns)) @@ -483,11 +490,14 @@ def gather_every(self, n: int, offset: int) -> Self: n_bytes=8, columns=self.columns, prefix="row_index_" ) plx = self.__narwhals_namespace__() + offset_expr = plx.lit(offset, dtype=None).broadcast() + n_expr = plx.lit(n, dtype=None).broadcast() + zero_expr = plx.lit(0, dtype=None).broadcast() return ( self.with_row_index(row_index_token, order_by=None) .filter( - (plx.col(row_index_token) >= offset) - & ((plx.col(row_index_token) - offset) % n == 0) + (plx.col(row_index_token) >= offset_expr) + & ((plx.col(row_index_token) - offset_expr) % n_expr == zero_expr) ) .drop([row_index_token], strict=False) ) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 1eb409ae09..0bbb89b6d8 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Callable, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, cast import pandas as pd @@ -11,10 +11,9 @@ from narwhals._dask.utils import ( add_row_index, align_series_full_broadcast, - maybe_evaluate_expr, narwhals_to_native_dtype, ) -from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.expr import window_kwargs_to_pandas_equivalent from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype from narwhals._utils import ( @@ -30,19 +29,20 @@ import dask.dataframe.dask_expr as dx from typing_extensions import Self - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs + from narwhals._compliant.typing import ( + AliasNames, + EvalNames, + EvalSeries, + NarwhalsAggregation, + ) from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace - from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._utils import Version, _LimitedContext from narwhals.typing import ( FillNullStrategy, IntoDType, ModeKeepStrategy, - NonNestedLiteral, - NumericLiteral, RollingInterpolationMethod, - TemporalLiteral, ) @@ -56,21 +56,14 @@ def __init__( self, call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm] *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[DaskLazyFrame], alias_output_names: AliasNames | None, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return self._call(df) @@ -80,7 +73,7 @@ def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover return DaskNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16 # that raised a KeyError for result[0] during collection. @@ -88,12 +81,9 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) @classmethod @@ -103,7 +93,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: @@ -118,8 +107,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, @@ -132,8 +119,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, @@ -144,15 +129,13 @@ def _with_callable( # First argument to `call` should be `dx.Series` call: Callable[..., dx.Series], /, - expr_name: str = "", - scalar_kwargs: ScalarKwargs | None = None, - **expressifiable_args: Self | Any, + **expressifiable_args: Self, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: native_results: list[dx.Series] = [] native_series_list = self._call(df) other_native_series = { - key: maybe_evaluate_expr(df, value) + key: df._evaluate_single_output_expr(value) for key, value in expressifiable_args.items() } for native_series in native_series_list: @@ -162,12 +145,9 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, - scalar_kwargs=scalar_kwargs, ) def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: @@ -181,12 +161,9 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: ) return type(self)( call=self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def _with_binary( @@ -197,9 +174,7 @@ def _with_binary( *, reverse: bool = False, ) -> Self: - result = self._with_callable( - lambda expr, other: call(expr, other), name, other=other - ) + result = self._with_callable(lambda expr, other: call(expr, other), other=other) if reverse: result = result.alias("literal") return result @@ -230,19 +205,17 @@ def __truediv__(self, other: Any) -> Self: def __floordiv__(self, other: Any) -> Self: def _floordiv( - df: DaskLazyFrame, series: dx.Series, other: dx.Series | Any + df: DaskLazyFrame, series: dx.Series, other: dx.Series ) -> dx.Series: series, other = align_series_full_broadcast(df, series, other) return (series.__floordiv__(other)).where(other != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - other_series = maybe_evaluate_expr(df, other) + other_series = df._evaluate_single_output_expr(other) return [_floordiv(df, series, other_series) for series in self(df)] return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->__floordiv__", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, @@ -286,18 +259,17 @@ def __rtruediv__(self, other: Any) -> Self: def __rfloordiv__(self, other: Any) -> Self: def _rfloordiv( - df: DaskLazyFrame, series: dx.Series, other: dx.Series | Any + df: DaskLazyFrame, series: dx.Series, other: dx.Series ) -> dx.Series: series, other = align_series_full_broadcast(df, series, other) return (other.__floordiv__(series)).where(series != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - return [_rfloordiv(df, series, other) for series in self(df)] + other_native = df._evaluate_single_output_expr(other) + return [_rfloordiv(df, series, other_native) for series in self(df)] return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->__rfloordiv__", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, @@ -310,10 +282,10 @@ def __rmod__(self, other: Any) -> Self: return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other) def __invert__(self) -> Self: - return self._with_callable(lambda expr: expr.__invert__(), "__invert__") + return self._with_callable(lambda expr: expr.__invert__()) def mean(self) -> Self: - return self._with_callable(lambda expr: expr.mean().to_series(), "mean") + return self._with_callable(lambda expr: expr.mean().to_series()) def median(self) -> Self: from narwhals.exceptions import InvalidOperationError @@ -325,36 +297,28 @@ def func(s: dx.Series) -> dx.Series: raise InvalidOperationError(msg) return s.median_approximate().to_series() - return self._with_callable(func, "median") + return self._with_callable(func) def min(self) -> Self: - return self._with_callable(lambda expr: expr.min().to_series(), "min") + return self._with_callable(lambda expr: expr.min().to_series()) def max(self) -> Self: - return self._with_callable(lambda expr: expr.max().to_series(), "max") + return self._with_callable(lambda expr: expr.max().to_series()) def std(self, *, ddof: int) -> Self: - return self._with_callable( - lambda expr: expr.std(ddof=ddof).to_series(), - "std", - scalar_kwargs={"ddof": ddof}, - ) + return self._with_callable(lambda expr: expr.std(ddof=ddof).to_series()) def var(self, *, ddof: int) -> Self: - return self._with_callable( - lambda expr: expr.var(ddof=ddof).to_series(), - "var", - scalar_kwargs={"ddof": ddof}, - ) + return self._with_callable(lambda expr: expr.var(ddof=ddof).to_series()) def skew(self) -> Self: - return self._with_callable(lambda expr: expr.skew().to_series(), "skew") + return self._with_callable(lambda expr: expr.skew().to_series()) def kurtosis(self) -> Self: - return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis") + return self._with_callable(lambda expr: expr.kurtosis().to_series()) def shift(self, n: int) -> Self: - return self._with_callable(lambda expr: expr.shift(n), "shift") + return self._with_callable(lambda expr: expr.shift(n)) def cum_sum(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover @@ -362,52 +326,48 @@ def cum_sum(self, *, reverse: bool) -> Self: msg = "`cum_sum(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cumsum(), "cum_sum") + return self._with_callable(lambda expr: expr.cumsum()) def cum_count(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_count(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable( - lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count" - ) + return self._with_callable(lambda expr: (~expr.isna()).astype(int).cumsum()) def cum_min(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_min(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cummin(), "cum_min") + return self._with_callable(lambda expr: expr.cummin()) def cum_max(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_max(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cummax(), "cum_max") + return self._with_callable(lambda expr: expr.cummax()) def cum_prod(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_prod(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cumprod(), "cum_prod") + return self._with_callable(lambda expr: expr.cumprod()) def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).sum(), - "rolling_sum", + ).sum() ) def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).mean(), - "rolling_mean", + ).mean() ) def rolling_var( @@ -417,8 +377,7 @@ def rolling_var( return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).var(), - "rolling_var", + ).var() ) msg = "Dask backend only supports `ddof=1` for `rolling_var`" raise NotImplementedError(msg) @@ -430,52 +389,49 @@ def rolling_std( return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).std(), - "rolling_std", + ).std() ) msg = "Dask backend only supports `ddof=1` for `rolling_std`" raise NotImplementedError(msg) def sum(self) -> Self: - return self._with_callable(lambda expr: expr.sum().to_series(), "sum") + return self._with_callable(lambda expr: expr.sum().to_series()) def count(self) -> Self: - return self._with_callable(lambda expr: expr.count().to_series(), "count") + return self._with_callable(lambda expr: expr.count().to_series()) def round(self, decimals: int) -> Self: - return self._with_callable(lambda expr: expr.round(decimals), "round") + return self._with_callable(lambda expr: expr.round(decimals)) def floor(self) -> Self: import dask.array as da - return self._with_callable(da.floor, "floor") + return self._with_callable(da.floor) def ceil(self) -> Self: import dask.array as da - return self._with_callable(da.ceil, "ceil") + return self._with_callable(da.ceil) def unique(self) -> Self: - return self._with_callable(lambda expr: expr.unique(), "unique") + return self._with_callable(lambda expr: expr.unique()) def drop_nulls(self) -> Self: - return self._with_callable(lambda expr: expr.dropna(), "drop_nulls") + return self._with_callable(lambda expr: expr.dropna()) def abs(self) -> Self: - return self._with_callable(lambda expr: expr.abs(), "abs") + return self._with_callable(lambda expr: expr.abs()) def all(self) -> Self: return self._with_callable( lambda expr: expr.all( axis=None, skipna=True, split_every=False, out=None - ).to_series(), - "all", + ).to_series() ) def any(self) -> Self: return self._with_callable( - lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(), - "any", + lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series() ) def fill_nan(self, value: float | None) -> Self: @@ -493,13 +449,10 @@ def func(expr: dx.Series) -> dx.Series: ) return expr.mask(mask, fill) # pyright: ignore[reportArgumentType] - return self._with_callable(func, "fill_nan") + return self._with_callable(func) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: def func(expr: dx.Series) -> dx.Series: if value is not None: @@ -512,32 +465,37 @@ def func(expr: dx.Series) -> dx.Series: ) return res_ser - return self._with_callable(func, "fill_null") + return self._with_callable(func) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: return self._with_callable( lambda expr, lower_bound, upper_bound: expr.clip( lower=lower_bound, upper=upper_bound ), - "clip", lower_bound=lower_bound, upper_bound=upper_bound, ) - def diff(self) -> Self: - return self._with_callable(lambda expr: expr.diff(), "diff") + def clip_lower(self, lower_bound: Self) -> Self: + return self._with_callable( + lambda expr, lower_bound: expr.clip(lower=lower_bound), + lower_bound=lower_bound, + ) - def n_unique(self) -> Self: + def clip_upper(self, upper_bound: Self) -> Self: return self._with_callable( - lambda expr: expr.nunique(dropna=False).to_series(), "n_unique" + lambda expr, upper_bound: expr.clip(upper=upper_bound), + upper_bound=upper_bound, ) + def diff(self) -> Self: + return self._with_callable(lambda expr: expr.diff()) + + def n_unique(self) -> Self: + return self._with_callable(lambda expr: expr.nunique(dropna=False).to_series()) + def is_null(self) -> Self: - return self._with_callable(lambda expr: expr.isna(), "is_null") + return self._with_callable(lambda expr: expr.isna()) def is_nan(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -549,10 +507,10 @@ def func(expr: dx.Series) -> dx.Series: msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?" raise InvalidOperationError(msg) - return self._with_callable(func, "is_null") + return self._with_callable(func) def len(self) -> Self: - return self._with_callable(lambda expr: expr.size.to_series(), "len") + return self._with_callable(lambda expr: expr.size.to_series()) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod @@ -567,11 +525,7 @@ def func(expr: dx.Series) -> dx.Series: q=quantile, method="dask" ).to_series() # pragma: no cover - return self._with_callable( - func, - "quantile", - scalar_kwargs={"quantile": quantile, "interpolation": "linear"}, - ) + return self._with_callable(func) msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead." raise NotImplementedError(msg) @@ -585,7 +539,7 @@ def func(expr: dx.Series) -> dx.Series: first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token] return frame[col_token].isin(first_distinct_index) - return self._with_callable(func, "is_first_distinct") + return self._with_callable(func) def is_last_distinct(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -597,7 +551,7 @@ def func(expr: dx.Series) -> dx.Series: last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token] return frame[col_token].isin(last_distinct_index) - return self._with_callable(func, "is_last_distinct") + return self._with_callable(func) def is_unique(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -609,15 +563,13 @@ def func(expr: dx.Series) -> dx.Series: == 1 ) - return self._with_callable(func, "is_unique") + return self._with_callable(func) def is_in(self, other: Any) -> Self: - return self._with_callable(lambda expr: expr.isin(other), "is_in") + return self._with_callable(lambda expr: expr.isin(other)) def null_count(self) -> Self: - return self._with_callable( - lambda expr: expr.isna().sum().to_series(), "null_count" - ) + return self._with_callable(lambda expr: expr.isna().sum().to_series()) def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: # pandas is a required dependency of dask so it's safe to import this @@ -642,7 +594,8 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: msg = "`over` with `order_by` is not yet supported in Dask." raise NotImplementedError(msg) else: - function_name = PandasLikeGroupBy._leaf_name(self) + leaf_node = next(self._metadata.op_nodes_reversed()) + function_name = cast("NarwhalsAggregation", leaf_node.name) try: dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name] except KeyError: @@ -653,7 +606,7 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: ) raise NotImplementedError(msg) from None dask_kwargs = window_kwargs_to_pandas_equivalent( - function_name, self._scalar_kwargs + function_name, leaf_node.kwargs ) def func(df: DaskLazyFrame) -> Sequence[dx.Series]: @@ -685,8 +638,6 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, @@ -697,12 +648,12 @@ def func(expr: dx.Series) -> dx.Series: native_dtype = narwhals_to_native_dtype(dtype, self._version) return expr.astype(native_dtype) - return self._with_callable(func, "cast") + return self._with_callable(func) def is_finite(self) -> Self: import dask.array as da - return self._with_callable(da.isfinite, "is_finite") + return self._with_callable(da.isfinite) def log(self, base: float) -> Self: import dask.array as da @@ -710,17 +661,17 @@ def log(self, base: float) -> Self: def _log(expr: dx.Series) -> dx.Series: return da.log(expr) / da.log(base) - return self._with_callable(_log, "log") + return self._with_callable(_log) def exp(self) -> Self: import dask.array as da - return self._with_callable(da.exp, "exp") + return self._with_callable(da.exp) def sqrt(self) -> Self: import dask.array as da - return self._with_callable(da.sqrt, "sqrt") + return self._with_callable(da.sqrt) def mode(self, *, keep: ModeKeepStrategy) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -728,7 +679,7 @@ def func(expr: dx.Series) -> dx.Series: result = expr.to_frame().mode()[_name] return result.head(1) if keep == "any" else result - return self._with_callable(func, "mode", scalar_kwargs={"keep": keep}) + return self._with_callable(func) @property def str(self) -> DaskExprStringNamespace: @@ -738,8 +689,9 @@ def str(self) -> DaskExprStringNamespace: def dt(self) -> DaskExprDateTimeNamespace: return DaskExprDateTimeNamespace(self) - rank = not_implemented() + filter = not_implemented() first = not_implemented() + rank = not_implemented() last = not_implemented() # namespaces diff --git a/narwhals/_dask/expr_dt.py b/narwhals/_dask/expr_dt.py index a3e3f8eab6..bc75b99747 100644 --- a/narwhals/_dask/expr_dt.py +++ b/narwhals/_dask/expr_dt.py @@ -25,70 +25,59 @@ class DaskExprDateTimeNamespace( LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"] ): def date(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.date, "date") + return self.compliant._with_callable(lambda expr: expr.dt.date) def year(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.year, "year") + return self.compliant._with_callable(lambda expr: expr.dt.year) def month(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.month, "month") + return self.compliant._with_callable(lambda expr: expr.dt.month) def day(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.day, "day") + return self.compliant._with_callable(lambda expr: expr.dt.day) def hour(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour") + return self.compliant._with_callable(lambda expr: expr.dt.hour) def minute(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute") + return self.compliant._with_callable(lambda expr: expr.dt.minute) def second(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.second, "second") + return self.compliant._with_callable(lambda expr: expr.dt.second) def millisecond(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.microsecond // 1000, "millisecond" - ) + return self.compliant._with_callable(lambda expr: expr.dt.microsecond // 1000) def microsecond(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.microsecond, "microsecond" - ) + return self.compliant._with_callable(lambda expr: expr.dt.microsecond) def nanosecond(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond" + lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond ) def ordinal_day(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.dayofyear, "ordinal_day" - ) + return self.compliant._with_callable(lambda expr: expr.dt.dayofyear) def weekday(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.weekday + 1, # Dask is 0-6 - "weekday", + lambda expr: expr.dt.weekday + 1 # Dask is 0-6 ) def to_string(self, format: str) -> DaskExpr: return self.compliant._with_callable( - lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")), - "strftime", - format=format, + lambda expr: expr.dt.strftime(format.replace("%.f", ".%f")) ) def replace_time_zone(self, time_zone: str | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone) + lambda expr: expr.dt.tz_localize(None).dt.tz_localize(time_zone) if time_zone is not None - else expr.dt.tz_localize(None), - "tz_localize", - time_zone=time_zone, + else expr.dt.tz_localize(None) ) def convert_time_zone(self, time_zone: str) -> DaskExpr: - def func(s: dx.Series, time_zone: str) -> dx.Series: + def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype( s.dtype, self.compliant._version, Implementation.DASK ) @@ -96,11 +85,11 @@ def func(s: dx.Series, time_zone: str) -> dx.Series: return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue] return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue] - return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone) + return self.compliant._with_callable(func) # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808. def timestamp(self, time_unit: TimeUnit) -> DaskExpr: # pragma: no cover - def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: + def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype( s.dtype, self.compliant._version, Implementation.DASK ) @@ -124,33 +113,27 @@ def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: raise TypeError(msg) return result.where(~mask_na) # pyright: ignore[reportReturnType] - return self.compliant._with_callable(func, "datetime", time_unit=time_unit) + return self.compliant._with_callable(func) def total_minutes(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() // 60, "total_minutes" - ) + return self.compliant._with_callable(lambda expr: expr.dt.total_seconds() // 60) def total_seconds(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() // 1, "total_seconds" - ) + return self.compliant._with_callable(lambda expr: expr.dt.total_seconds() // 1) def total_milliseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1, - "total_milliseconds", + lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1 ) def total_microseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1, - "total_microseconds", + lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1 ) def total_nanoseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds" + lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1 ) def truncate(self, every: str) -> DaskExpr: @@ -160,10 +143,10 @@ def truncate(self, every: str) -> DaskExpr: msg = f"Truncating to {unit} is not yet supported for dask." raise NotImplementedError(msg) freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}" - return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate") + return self.compliant._with_callable(lambda expr: expr.dt.floor(freq)) def offset_by(self, by: str) -> DaskExpr: - def func(s: dx.Series, by: str) -> dx.Series: + def func(s: dx.Series) -> dx.Series: interval = Interval.parse_no_constraints(by) unit = interval.unit if unit in {"y", "q", "mo", "d", "ns"}: @@ -172,4 +155,4 @@ def func(s: dx.Series, by: str) -> dx.Series: offset = interval.to_timedelta() return s.add(offset) - return self.compliant._with_callable(func, "offset_by", by=by) + return self.compliant._with_callable(func) diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 200a358297..0a5e036e3c 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -16,111 +16,65 @@ class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]): def len_chars(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.str.len(), "len") - - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr: - def _replace( - expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int - ) -> dx.Series: - try: - return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=n - ) - except TypeError as e: - if not isinstance(value, str): - msg = "dask backed `Expr.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + return self.compliant._with_callable(lambda expr: expr.str.len()) - return self.compliant._with_callable( - _replace, "replace", pattern=pattern, value=value, literal=literal, n=n - ) + def replace( + self, value: DaskExpr, pattern: str, *, literal: bool, n: int + ) -> DaskExpr: + if not value._metadata.is_literal: + msg = "dask backed `Expr.str.replace` only supports str replacement values" + raise TypeError(msg) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr: - def _replace_all( - expr: dx.Series, pattern: str, value: str, *, literal: bool - ) -> dx.Series: - try: - return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=-1 - ) - except TypeError as e: - if not isinstance(value, str): - msg = "dask backed `Expr.str.replace_all` only supports str replacement values." - raise TypeError(msg) from e - raise + def _replace(expr: dx.Series, value: dx.Series) -> dx.Series: + # OK to call `compute` here as `value` is just a literal expression. + return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] + pattern, value.compute(), regex=not literal, n=n + ) - return self.compliant._with_callable( - _replace_all, "replace", pattern=pattern, value=value, literal=literal - ) + return self.compliant._with_callable(_replace, value=value) + + def replace_all(self, value: DaskExpr, pattern: str, *, literal: bool) -> DaskExpr: + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, characters: expr.str.strip(characters), - "strip", - characters=characters, - ) + return self.compliant._with_callable(lambda expr: expr.str.strip(characters)) def starts_with(self, prefix: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix - ) + return self.compliant._with_callable(lambda expr: expr.str.startswith(prefix)) def ends_with(self, suffix: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix - ) + return self.compliant._with_callable(lambda expr: expr.str.endswith(suffix)) def contains(self, pattern: str, *, literal: bool) -> DaskExpr: return self.compliant._with_callable( - lambda expr, pattern, literal: expr.str.contains( - pat=pattern, regex=not literal - ), - "contains", - pattern=pattern, - literal=literal, + lambda expr: expr.str.contains(pat=pattern, regex=not literal) ) def slice(self, offset: int, length: int | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, offset, length: expr.str.slice( + lambda expr: expr.str.slice( start=offset, stop=offset + length if length else None - ), - "slice", - offset=offset, - length=length, + ) ) def split(self, by: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, by: expr.str.split(pat=by), "split", by=by - ) + return self.compliant._with_callable(lambda expr: expr.str.split(pat=by)) def to_datetime(self, format: str | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, format: dd.to_datetime(expr, format=format), - "to_datetime", - format=format, + lambda expr: dd.to_datetime(expr, format=format) ) def to_uppercase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.upper(), "to_uppercase" - ) + return self.compliant._with_callable(lambda expr: expr.str.upper()) def to_lowercase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.lower(), "to_lowercase" - ) + return self.compliant._with_callable(lambda expr: expr.str.lower()) def to_titlecase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.title(), "to_titlecase" - ) + return self.compliant._with_callable(lambda expr: expr.str.title()) def zfill(self, width: int) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, width: expr.str.zfill(width), "zfill", width=width - ) + return self.compliant._with_callable(lambda expr: expr.str.zfill(width)) to_date = not_implemented() diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index e280dcf1a3..3eb47a2ec3 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -126,17 +126,18 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: output_names, aliases = evaluate_output_names_and_aliases( expr, self.compliant, exclude ) - if expr._depth == 0: + last_node = next(expr._metadata.op_nodes_reversed()) + if len(list(expr._metadata.op_nodes_reversed())) == 1: # e.g. `agg(nw.len())` column = self._keys[0] - agg_fn = self._remap_expr_name(expr._function_name) + agg_fn = self._remap_expr_name(last_node.name) simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn))) continue # e.g. `agg(nw.mean('a'))` agg_fn = self._remap_expr_name(self._leaf_name(expr)) # deal with n_unique case in a "lazy" mode to not depend on dask globally - agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn + agg_fn = agg_fn(**last_node.kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( (alias, (output_name, agg_fn)) for alias, output_name in zip_strict(aliases, output_names) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index df66a583c4..c4791a7e0d 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from datetime import date, datetime from functools import reduce from itertools import chain from typing import TYPE_CHECKING, cast @@ -8,12 +9,7 @@ import dask.dataframe as dd import pandas as pd -from narwhals._compliant import ( - CompliantThen, - CompliantWhen, - DepthTrackingNamespace, - LazyNamespace, -) +from narwhals._compliant import DepthTrackingNamespace, LazyNamespace from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace @@ -23,18 +19,16 @@ validate_comparand, ) from narwhals._expression_parsing import ( - ExprKind, combine_alias_output_names, combine_evaluate_output_names, ) from narwhals._utils import Implementation, zip_strict if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Iterable, Iterator import dask.dataframe.dask_expr as dx - from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral @@ -65,6 +59,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: native_dtype = narwhals_to_native_dtype(dtype, self._version) native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") + elif isinstance(value, date) and not isinstance( + value, datetime + ): # pragma: no cover + # Dask auto-infers this as object type, which causes issues down the line. + # This shows up in TPC-H q8. + native_dtype = "date32[pyarrow]" + native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") else: native_pd_series = pd.Series([value], name="literal") npartitions = df._native_frame.npartitions @@ -73,8 +74,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( func, - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -87,8 +86,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( func, - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, version=self._version, @@ -106,8 +103,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -122,8 +117,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -138,8 +131,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -188,8 +179,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -205,8 +194,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -222,16 +209,11 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, ) - def when(self, predicate: DaskExpr) -> DaskWhen: - return DaskWhen.from_expr(predicate, context=self) - def concat_str( self, *exprs: DaskExpr, separator: str, ignore_nulls: bool ) -> DaskExpr: @@ -266,8 +248,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=getattr( exprs[0], "_evaluate_output_names", lambda _df: ["literal"] ), @@ -284,55 +264,55 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, ) + def when_then( + self, predicate: DaskExpr, then: DaskExpr, otherwise: DaskExpr | None = None + ) -> DaskExpr: + def func(df: DaskLazyFrame) -> list[dx.Series]: + then_value = df._evaluate_single_output_expr(then) + otherwise_value = ( + df._evaluate_single_output_expr(otherwise) + if otherwise is not None + else otherwise + ) -class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments] - @property - def _then(self) -> type[DaskThen]: - return DaskThen - - def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: - then_value = ( - self._then_value(df)[0] - if isinstance(self._then_value, DaskExpr) - else self._then_value - ) - otherwise_value = ( - self._otherwise_value(df)[0] - if isinstance(self._otherwise_value, DaskExpr) - else self._otherwise_value - ) - - condition = self._condition(df)[0] - # re-evaluate DataFrame if the condition aggregates to force - # then/otherwise to be evaluated against the aggregated frame - assert self._condition._metadata is not None # noqa: S101 - if self._condition._metadata.is_scalar_like: - new_df = df._with_native(condition.to_frame()) - condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0] - df = new_df - - if self._otherwise_value is None: - (condition, then_series) = align_series_full_broadcast( - df, condition, then_value + condition = df._evaluate_single_output_expr(predicate) + # re-evaluate DataFrame if the condition aggregates to force + # then/otherwise to be evaluated against the aggregated frame + if all( + x._metadata.is_scalar_like + for x in ( + (predicate, then) + if otherwise is None + else (predicate, then, otherwise) + ) + ): + new_df = df._with_native(condition.to_frame()) + condition = df._evaluate_single_output_expr(predicate.broadcast()) + df = new_df + + if otherwise is None: + (condition, then_series) = align_series_full_broadcast( + df, condition, then_value + ) + validate_comparand(condition, then_series) + return [then_series.where(condition)] # pyright: ignore[reportArgumentType] + (condition, then_series, otherwise_series) = align_series_full_broadcast( + df, condition, then_value, otherwise_value ) validate_comparand(condition, then_series) - return [then_series.where(condition)] # pyright: ignore[reportArgumentType] - (condition, then_series, otherwise_series) = align_series_full_broadcast( - df, condition, then_value, otherwise_value - ) - validate_comparand(condition, then_series) - validate_comparand(condition, otherwise_series) - return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] + validate_comparand(condition, otherwise_series) + return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] - -class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" + return self._expr( + call=func, + evaluate_output_names=getattr( + then, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(then, "_alias_output_names", None), + version=self._version, + ) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 9fb6eeecb8..501662422d 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: import dask.dataframe.dask_expr as dx # noqa: F401 - from narwhals._compliant.typing import ScalarKwargs from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401 @@ -19,15 +18,9 @@ def _selector(self) -> type[DaskSelector]: class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr): # type: ignore[misc] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> DaskExpr: return DaskExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 0827864c7f..0f1eee32a6 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -23,16 +23,6 @@ import dask_expr as dx -def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object: - from narwhals._dask.expr import DaskExpr - - if isinstance(obj, DaskExpr): - results = obj._call(df) - assert len(results) == 1 # debug assertion # noqa: S101 - return results[0] - return obj - - def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]: native_results: list[tuple[str, dx.Series]] = [] for expr in exprs: diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 42d5117946..d9ad44e4e0 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -12,7 +12,7 @@ F, catch_duckdb_exception, col, - evaluate_exprs, + evaluate_exprs_and_aliases, join_column_names, lit, native_to_narwhals_dtype, @@ -25,7 +25,6 @@ Version, extend_bool, generate_temporary_column_name, - not_implemented, parse_columns_to_drop, requires, zip_strict, @@ -177,14 +176,18 @@ def simple_select(self, *column_names: str) -> Self: return self._with_native(self.native.select(*column_names)) def aggregate(self, *exprs: DuckDBExpr) -> Self: - selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)] + selection = [ + val.alias(name) for name, val in evaluate_exprs_and_aliases(self, *exprs) + ] try: return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type] except Exception as e: # noqa: BLE001 raise catch_duckdb_exception(e, self) from None def select(self, *exprs: DuckDBExpr) -> Self: - selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs)) + selection = ( + val.alias(name) for name, val in evaluate_exprs_and_aliases(self, *exprs) + ) try: return self._with_native(self.native.select(*selection)) except Exception as e: # noqa: BLE001 @@ -206,7 +209,7 @@ def lazy(self, backend: None = None, **_: None) -> Self: return self def with_columns(self, *exprs: DuckDBExpr) -> Self: - new_columns_map = dict(evaluate_exprs(self, *exprs)) + new_columns_map = dict(evaluate_exprs_and_aliases(self, *exprs)) result = [ new_columns_map.pop(name).alias(name) if name in new_columns_map @@ -224,8 +227,8 @@ def filter(self, predicate: DuckDBExpr) -> Self: mask = predicate(self)[0] try: return self._with_native(self.native.filter(mask)) - except Exception as e: # noqa: BLE001 - raise catch_duckdb_exception(e, self) from None + except Exception as e: + raise catch_duckdb_exception(e, self) from e @property def schema(self) -> dict[str, DType]: @@ -480,9 +483,8 @@ def explode(self, columns: Sequence[str]) -> Self: rel = self.native original_columns = self.columns - not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > lit( - 0 - ) + zero = lit(0) + not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > zero non_null_rel = rel.filter(not_null_condition).select( *( F("unnest", col_to_explode).alias(name) if name in columns else name @@ -550,10 +552,3 @@ def sink_parquet(self, file: str | Path | BytesIO) -> None: (FORMAT parquet) """ # noqa: S608 duckdb.sql(query) - - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index bdbd8d5f6f..58ee4a7977 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, cast from duckdb import CoalesceOperator, StarExpression @@ -20,7 +20,6 @@ when, window_expression, ) -from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._sql.expr import SQLExpr from narwhals._utils import Implementation, Version, extend_bool @@ -40,12 +39,7 @@ from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._utils import _LimitedContext - from narwhals.typing import ( - FillNullStrategy, - IntoDType, - NonNestedLiteral, - RollingInterpolationMethod, - ) + from narwhals.typing import FillNullStrategy, IntoDType, RollingInterpolationMethod DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression] DuckDBWindowInputs = WindowInputs[Expression] @@ -68,7 +62,6 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._metadata: ExprMetadata | None = None self._window_function: DuckDBWindowFunction | None = window_function def _count_star(self) -> Expression: @@ -116,12 +109,7 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover return DuckDBNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.LITERAL: - return self - if self._backend_version < (1, 3): - msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." - raise NotImplementedError(msg) + def broadcast(self) -> Self: return self.over([lit(1)], []) @classmethod @@ -231,10 +219,7 @@ def is_in(self, other: Sequence[Any]) -> Self: return self._with_elementwise(lambda expr: F("contains", lit(other), expr)) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: if self._backend_version < (1, 3): # pragma: no cover @@ -267,6 +252,7 @@ def _fill_with_strategy( def _fill_constant(expr: Expression, value: Any) -> Expression: return CoalesceOperator(expr, value) + assert value is not None # noqa: S101 return self._with_elementwise(_fill_constant, value=value) def cast(self, dtype: IntoDType) -> Self: diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 09b5ecd8eb..ebc5041e68 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -26,7 +26,6 @@ combine_evaluate_output_names, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen from narwhals._utils import Implementation if TYPE_CHECKING: @@ -34,6 +33,7 @@ from duckdb import DuckDBPyRelation # noqa: F401 + from narwhals._compliant.window import WindowInputs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral @@ -130,9 +130,6 @@ def func(cols: Iterable[Expression]) -> Expression: return self._expr._from_elementwise_horizontal_op(func, *exprs) - def when(self, predicate: DuckDBExpr) -> DuckDBWhen: - return DuckDBWhen.from_expr(predicate, context=self) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: def func(df: DuckDBLazyFrame) -> list[Expression]: tz = DeferredTimeZone(df.native) @@ -141,8 +138,14 @@ def func(df: DuckDBLazyFrame) -> list[Expression]: return [lit(value).cast(target)] return [lit(value)] + def window_func( + df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] + ) -> list[Expression]: + return func(df) + return self._expr( func, + window_func, evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -158,12 +161,3 @@ def func(_df: DuckDBLazyFrame) -> list[Expression]: alias_output_names=None, version=self._version, ) - - -class DuckDBWhen(SQLWhen["DuckDBLazyFrame", Expression, DuckDBExpr]): - @property - def _then(self) -> type[DuckDBThen]: - return DuckDBThen - - -class DuckDBThen(SQLThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ... diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 6304ba86d7..37d86a88ba 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -87,12 +87,12 @@ def concat_str(*exprs: Expression, separator: str = "") -> Expression: return F("concat_ws", lit(separator), *exprs) if separator else F("concat", *exprs) -def evaluate_exprs( +def evaluate_exprs_and_aliases( df: DuckDBLazyFrame, /, *exprs: DuckDBExpr ) -> list[tuple[str, Expression]]: native_results: list[tuple[str, Expression]] = [] for expr in exprs: - native_series_list = expr._call(df) + native_series_list = expr(df) output_names = expr._evaluate_output_names(df) if expr._alias_output_names is not None: output_names = expr._alias_output_names(output_names) @@ -406,7 +406,6 @@ def sql_expression(expr: str) -> Expression: "col", "concat_str", "duckdb_dtypes", - "evaluate_exprs", "fetch_rel_time_zone", "function", "generate_order_by_sql", diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index b5cbd85f6b..6f1c3205b9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,15 +5,18 @@ from __future__ import annotations from enum import Enum, auto -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Literal, cast -from narwhals._utils import is_compliant_expr, zip_strict -from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d -from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError +from narwhals._utils import zip_strict +from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import ( + InvalidIntoExprError, + InvalidOperationError, + MultiOutputExpressionError, +) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence from typing_extensions import Never, TypeIs @@ -29,8 +32,6 @@ from narwhals.series import Series from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray - T = TypeVar("T") - def is_expr(obj: Any) -> TypeIs[Expr]: """Check whether `obj` is a Narwhals Expr.""" @@ -46,22 +47,11 @@ def is_series(obj: Any) -> TypeIs[Series[Any]]: return isinstance(obj, Series) -def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]: - from narwhals.expr import Expr - from narwhals.series import Series - - return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj) - - def combine_evaluate_output_names( *exprs: CompliantExpr[CompliantFrameT, Any], ) -> EvalNames[CompliantFrameT]: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the # first name of `expr1`. - if not is_compliant_expr(exprs[0]): # pragma: no cover - msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug." - raise AssertionError(msg) - def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]: return exprs[0]._evaluate_output_names(df)[:1] @@ -89,16 +79,14 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if exclude: - assert expr._metadata is not None # noqa: S101 - if expr._metadata.expansion_kind.is_multi_unnamed(): - output_names, aliases = zip_strict( - *[ - (x, alias) - for x, alias in zip_strict(output_names, aliases) - if x not in exclude - ] - ) + if exclude and expr._metadata.expansion_kind.is_multi_unnamed(): + output_names, aliases = zip_strict( + *[ + (x, alias) + for x, alias in zip_strict(output_names, aliases) + if x not in exclude + ] + ) return output_names, aliases @@ -132,47 +120,57 @@ class ExprKind(Enum): OVER = auto() """Results from calling `.over` on expression.""" - UNKNOWN = auto() - """Based on the information we have, we can't determine the ExprKind.""" + COL = auto() + """Results from calling `nw.col`.""" - @property - def is_scalar_like(self) -> bool: - return self in {ExprKind.LITERAL, ExprKind.AGGREGATION} + NTH = auto() + """Results from calling `nw.nth`.""" - @property - def is_orderable_window(self) -> bool: - return self in {ExprKind.ORDERABLE_WINDOW, ExprKind.ORDERABLE_AGGREGATION} + EXCLUDE = auto() + """Results from calling `nw.exclude`.""" - @classmethod - def from_expr(cls, obj: Expr) -> ExprKind: - meta = obj._metadata - if meta.is_literal: - return ExprKind.LITERAL - if meta.is_scalar_like: - return ExprKind.AGGREGATION - if meta.is_elementwise: - return ExprKind.ELEMENTWISE - return ExprKind.UNKNOWN + ALL = auto() + """Results from calling `nw.all`.""" - @classmethod - def from_into_expr( - cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool - ) -> ExprKind: - if is_expr(obj): - return cls.from_expr(obj) - if ( - is_narwhals_series(obj) - or is_numpy_array(obj) - or (isinstance(obj, str) and not str_as_lit) - ): - return ExprKind.ELEMENTWISE - return ExprKind.LITERAL + SELECTOR = auto() + """Results from creating an expression with a selector.""" + WHEN_THEN = auto() + """Results from `when/then expression`, possibly followed by `otherwise`.""" -def is_scalar_like( - obj: ExprKind, -) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]: - return obj.is_scalar_like + SERIES = auto() + """Results from converting a Series to Expr.""" + + @property + def is_orderable(self) -> bool: + # Any operation which may be affected by `order_by`, such as `cum_sum`, + # `diff`, `rank`, `arg_max`, ... + return self in { + ExprKind.ORDERABLE_WINDOW, + ExprKind.WINDOW, + ExprKind.ORDERABLE_AGGREGATION, + ExprKind.ORDERABLE_FILTRATION, + } + + @property + def is_elementwise(self) -> bool: + # Any operation which can operate on each row independently + # of the rows around it, e.g. `abs(), __add__, sum_horizontal, ...` + return self in { + ExprKind.ALL, + ExprKind.COL, + ExprKind.ELEMENTWISE, + ExprKind.EXCLUDE, + ExprKind.LITERAL, + ExprKind.NTH, + ExprKind.SELECTOR, + ExprKind.SERIES, + ExprKind.WHEN_THEN, + } + + +def is_scalar_like(obj: CompliantExprAny) -> bool: + return obj._metadata.is_scalar_like class ExpansionKind(Enum): @@ -202,6 +200,126 @@ def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]: raise AssertionError(msg) # pragma: no cover +class ExprNode: + """An operation to create or modify an expression. + + Parameters: + kind: ExprKind of operation. + name: Name of function, as defined in the compliant protocols. + exprs: Expressifiable arguments to function. + str_as_lit: Whether to interpret strings as literals when they + are present in `exprs`. + allow_multi_output: Whether to allow any of `exprs` to be multi-output. + kwargs: Other (non-expressifiable) arguments to function. + """ + + def __init__( + self, + kind: ExprKind, + name: str, + /, + *exprs: IntoExpr | NonNestedLiteral, + str_as_lit: bool = False, + allow_multi_output: bool = False, + **kwargs: Any, + ) -> None: + self.kind: ExprKind = kind + self.name: str = name + self.exprs: Sequence[IntoExpr | NonNestedLiteral] = exprs + self.kwargs: dict[str, Any] = kwargs + self.str_as_lit: bool = str_as_lit + self.allow_multi_output: bool = allow_multi_output + + # Cached methods. + self._is_orderable_cached: bool | None = None + self._is_elementwise_cached: bool | None = None + + def __repr__(self) -> str: + if self.name == "col": + names = ", ".join(str(x) for x in self.kwargs["names"]) + return f"col({names})" + arg_str = [] + expr_repr = ", ".join(str(x) for x in self.exprs) + kwargs_repr = ", ".join(f"{key}={value}" for key, value in self.kwargs.items()) + if self.exprs: + arg_str.append(expr_repr) + if self.kwargs: + arg_str.append(kwargs_repr) + return f"{self.name}({', '.join(arg_str)})" + + def as_dict(self) -> dict[str, Any]: # pragma: no cover + # Just for debugging. + return { + "kind": self.kind, + "name": self.name, + "exprs": self.exprs, + "kwargs": self.kwargs, + "str_as_lit": self.str_as_lit, + "allow_multi_output": self.allow_multi_output, + } + + def _with_kwargs(self, **kwargs: Any) -> ExprNode: + return self.__class__( + self.kind, self.name, *self.exprs, str_as_lit=self.str_as_lit, **kwargs + ) + + def _push_down_over_node_in_place( + self, over_node: ExprNode, over_node_without_order_by: ExprNode + ) -> None: + exprs: list[IntoExpr | NonNestedLiteral] = [] + # Note: please keep this as a for-loop (rather than a list-comprehension) + # so that pytest-cov highlights any uncovered branches. + over_node_order_by = over_node.kwargs["order_by"] + over_node_partition_by = over_node.kwargs["partition_by"] + for expr in self.exprs: + if not is_expr(expr): + exprs.append(expr) + elif over_node_order_by and any( + expr_node.is_orderable() for expr_node in expr._nodes + ): + exprs.append(expr._with_over_node(over_node)) + elif over_node_partition_by and not all( + expr_node.is_elementwise() for expr_node in expr._nodes + ): + exprs.append(expr._with_over_node(over_node_without_order_by)) + else: + # If there's no `partition_by`, then `over_node_without_order_by` is a no-op. + exprs.append(expr) + self.exprs = exprs + + def is_orderable(self) -> bool: + if self._is_orderable_cached is None: + # Note: don't combine these if/then statements so that pytest-cov shows if + # anything is uncovered. + if self.kind.is_orderable: # noqa: SIM114 + self._is_orderable_cached = True + elif any( + any(node.is_orderable() for node in expr._nodes) + for expr in self.exprs + if is_expr(expr) + ): + self._is_orderable_cached = True + else: + self._is_orderable_cached = False + return self._is_orderable_cached + + def is_elementwise(self) -> bool: + if self._is_elementwise_cached is None: + # Note: don't combine these if/then statements so that pytest-cov shows if + # anything is uncovered. + if not self.kind.is_elementwise: # noqa: SIM114 + self._is_elementwise_cached = False + elif any( + any(not node.is_elementwise() for node in expr._nodes) + for expr in self.exprs + if is_expr(expr) + ): + self._is_elementwise_cached = False + else: + self._is_elementwise_cached = True + return self._is_elementwise_cached + + class ExprMetadata: """Expression metadata. @@ -212,28 +330,29 @@ class ExprMetadata: of the other rows around it. is_literal: Whether it is just a literal wrapped in an expression. is_scalar_like: Whether it is a literal or an aggregation. - last_node: The ExprKind of the last node. n_orderable_ops: The number of order-dependent operations. In the lazy case, this number must be `0` by the time the expression is evaluated. preserves_length: Whether the expression preserves the input length. + current_node: The current ExprNode in the linked list. + prev: Reference to the previous ExprMetadata in the linked list (None for root). """ __slots__ = ( + "current_node", "expansion_kind", "has_windows", "is_elementwise", "is_literal", "is_scalar_like", - "last_node", "n_orderable_ops", "preserves_length", + "prev", ) def __init__( self, expansion_kind: ExpansionKind, - last_node: ExprKind, *, has_windows: bool = False, n_orderable_ops: int = 0, @@ -241,93 +360,187 @@ def __init__( is_elementwise: bool = True, is_scalar_like: bool = False, is_literal: bool = False, + current_node: ExprNode, + prev: ExprMetadata | None = None, ) -> None: if is_literal: assert is_scalar_like # noqa: S101 # debug assertion if is_elementwise: assert preserves_length # noqa: S101 # debug assertion self.expansion_kind: ExpansionKind = expansion_kind - self.last_node: ExprKind = last_node self.has_windows: bool = has_windows self.n_orderable_ops: int = n_orderable_ops self.is_elementwise: bool = is_elementwise self.preserves_length: bool = preserves_length self.is_scalar_like: bool = is_scalar_like self.is_literal: bool = is_literal + self.current_node: ExprNode = current_node + self.prev: ExprMetadata | None = prev def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover msg = f"Cannot subclass {cls.__name__!r}" raise TypeError(msg) def __repr__(self) -> str: # pragma: no cover + nodes = tuple(reversed(tuple(self.iter_nodes_reversed()))) return ( f"ExprMetadata(\n" f" expansion_kind: {self.expansion_kind},\n" - f" last_node: {self.last_node},\n" f" has_windows: {self.has_windows},\n" f" n_orderable_ops: {self.n_orderable_ops},\n" f" is_elementwise: {self.is_elementwise},\n" f" preserves_length: {self.preserves_length},\n" f" is_scalar_like: {self.is_scalar_like},\n" f" is_literal: {self.is_literal},\n" + f" nodes: {nodes},\n" ")" ) + def iter_nodes_reversed(self) -> Iterator[ExprNode]: + """Iterate through all nodes from current to root.""" + current: ExprMetadata | None = self + while current is not None: + yield current.current_node + current = current.prev + + @classmethod + def from_node( + cls, node: ExprNode, *compliant_exprs: CompliantExprAny + ) -> ExprMetadata: + return KIND_TO_METADATA_CONSTRUCTOR[node.kind](node, *compliant_exprs) + + def with_node( + self, + node: ExprNode, + compliant_expr: CompliantExprAny, + *compliant_expr_args: CompliantExprAny, + ) -> ExprMetadata: + return KIND_TO_METADATA_UPDATER[node.kind]( + self, node, compliant_expr, *compliant_expr_args + ) + + @classmethod + def from_aggregation(cls, node: ExprNode) -> ExprMetadata: + return cls( + ExpansionKind.SINGLE, + is_elementwise=False, + preserves_length=False, + is_scalar_like=True, + current_node=node, + prev=None, + ) + + @classmethod + def from_literal(cls, node: ExprNode) -> ExprMetadata: + return cls( + ExpansionKind.SINGLE, + is_elementwise=False, + preserves_length=False, + is_literal=True, + is_scalar_like=True, + current_node=node, + prev=None, + ) + + @classmethod + def from_series(cls, node: ExprNode) -> ExprMetadata: + return cls(ExpansionKind.SINGLE, current_node=node, prev=None) + + @classmethod + def from_col(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.col('a')`, `nw.nth(0)` + return ( + cls(ExpansionKind.SINGLE, current_node=node, prev=None) + if len(node.kwargs["names"]) == 1 + else cls.from_selector_multi_named(node) + ) + + @classmethod + def from_nth(cls, node: ExprNode) -> ExprMetadata: + return ( + cls(ExpansionKind.SINGLE, current_node=node, prev=None) + if len(node.kwargs["indices"]) == 1 + else cls.from_selector_multi_unnamed(node) + ) + + @classmethod + def from_selector_multi_named(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.col('a', 'b')` + return cls(ExpansionKind.MULTI_NAMED, current_node=node, prev=None) + + @classmethod + def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.all()` + return cls(ExpansionKind.MULTI_UNNAMED, current_node=node, prev=None) + + @classmethod + def from_elementwise( + cls, node: ExprNode, *compliant_exprs: CompliantExprAny + ) -> ExprMetadata: + return combine_metadata( + *compliant_exprs, to_single_output=True, current_node=node, prev=None + ) + @property def is_filtration(self) -> bool: return not self.preserves_length and not self.is_scalar_like - def with_aggregation(self) -> ExprMetadata: + def with_aggregation(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply aggregations to scalar-like expressions." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.AGGREGATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops, preserves_length=False, is_elementwise=False, is_scalar_like=True, is_literal=False, + current_node=node, + prev=self, ) - def with_orderable_aggregation(self) -> ExprMetadata: + def with_orderable_aggregation( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: # Deprecated, used only in stable.v1. if self.is_scalar_like: # pragma: no cover msg = "Can't apply aggregations to scalar-like expressions." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_AGGREGATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=False, is_elementwise=False, is_scalar_like=True, is_literal=False, + current_node=node, + prev=self, ) - def with_elementwise_op(self) -> ExprMetadata: - return ExprMetadata( - self.expansion_kind, - ExprKind.ELEMENTWISE, - has_windows=self.has_windows, - n_orderable_ops=self.n_orderable_ops, - preserves_length=self.preserves_length, - is_elementwise=self.is_elementwise, - is_scalar_like=self.is_scalar_like, - is_literal=self.is_literal, + def with_elementwise( + self, + node: ExprNode, + compliant_expr: CompliantExprAny, + *compliant_expr_args: CompliantExprAny, + ) -> ExprMetadata: + return combine_metadata( + compliant_expr, + *compliant_expr_args, + to_single_output=False, + current_node=node, + prev=compliant_expr._metadata, ) - def with_window(self) -> ExprMetadata: + def with_window(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: # Window function which may (but doesn't have to) be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply window (e.g. `rank`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.WINDOW, has_windows=self.has_windows, # The function isn't order-dependent (but, users can still use `order_by` if they wish!), # so we don't increment `n_orderable_ops`. @@ -336,25 +549,30 @@ def with_window(self) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - def with_orderable_window(self) -> ExprMetadata: + def with_orderable_window( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: # Window function which must be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_WINDOW, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=self.preserves_length, is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - def with_ordered_over(self) -> ExprMetadata: + def with_ordered_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -365,7 +583,10 @@ def with_ordered_over(self) -> ExprMetadata: ) raise InvalidOperationError(msg) n_orderable_ops = self.n_orderable_ops - if not n_orderable_ops and self.last_node is not ExprKind.WINDOW: + if ( + not n_orderable_ops + and next(self.op_nodes_reversed()).kind is not ExprKind.WINDOW + ): msg = ( "Cannot use `order_by` in `over` on expression which isn't orderable.\n" "If your expression is orderable, then make sure that `over(order_by=...)`\n" @@ -376,20 +597,23 @@ def with_ordered_over(self) -> ExprMetadata: " + `nw.col('price').diff().over(order_by='date') + 1`\n" ) raise InvalidOperationError(msg) - if self.last_node.is_orderable_window: + if next(self.op_nodes_reversed()).kind.is_orderable and n_orderable_ops > 0: n_orderable_ops -= 1 return ExprMetadata( self.expansion_kind, - ExprKind.OVER, has_windows=True, n_orderable_ops=n_orderable_ops, preserves_length=True, is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - def with_partitioned_over(self) -> ExprMetadata: + def with_partitioned_over( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -401,110 +625,108 @@ def with_partitioned_over(self) -> ExprMetadata: raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.OVER, has_windows=True, n_orderable_ops=self.n_orderable_ops, preserves_length=True, is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - def with_filtration(self) -> ExprMetadata: + def with_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: + if node.kwargs["order_by"]: + return self.with_ordered_over(node, _ce) + if not node.kwargs["partition_by"]: # pragma: no cover + msg = "At least one of `partition_by` or `order_by` must be specified." + raise InvalidOperationError(msg) + return self.with_partitioned_over(node, _ce) + + def with_filtration( + self, node: ExprNode, *compliant_exprs: CompliantExprAny + ) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) + result_has_windows = any(x._metadata.has_windows for x in compliant_exprs) + result_n_orderable_ops = sum(x._metadata.n_orderable_ops for x in compliant_exprs) return ExprMetadata( self.expansion_kind, - ExprKind.FILTRATION, - has_windows=self.has_windows, - n_orderable_ops=self.n_orderable_ops, + has_windows=result_has_windows, + n_orderable_ops=result_n_orderable_ops, preserves_length=False, is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - def with_orderable_filtration(self) -> ExprMetadata: + def with_orderable_filtration( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_FILTRATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=False, is_elementwise=False, is_scalar_like=False, is_literal=False, + current_node=node, + prev=self, ) - @staticmethod - def aggregation() -> ExprMetadata: - return ExprMetadata( - ExpansionKind.SINGLE, - ExprKind.AGGREGATION, - is_elementwise=False, - preserves_length=False, - is_scalar_like=True, - ) - - @staticmethod - def literal() -> ExprMetadata: - return ExprMetadata( - ExpansionKind.SINGLE, - ExprKind.LITERAL, - is_elementwise=False, - preserves_length=False, - is_literal=True, - is_scalar_like=True, - ) - - @staticmethod - def selector_single() -> ExprMetadata: - # e.g. `nw.col('a')`, `nw.nth(0)` - return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE) - - @staticmethod - def selector_multi_named() -> ExprMetadata: - # e.g. `nw.col('a', 'b')` - return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE) - - @staticmethod - def selector_multi_unnamed() -> ExprMetadata: - # e.g. `nw.all()` - return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE) - - @classmethod - def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata: - # We may be able to allow multi-output rhs in the future: - # https://github.com/narwhals-dev/narwhals/issues/2244. - return combine_metadata( - lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False - ) - - @classmethod - def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata: - return combine_metadata( - *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True - ) + def op_nodes_reversed(self) -> Iterator[ExprNode]: + for node in self.iter_nodes_reversed(): + if node.name.startswith(("name.", "alias")): + # Skip nodes which only do aliasing. + continue + yield node + + +KIND_TO_METADATA_CONSTRUCTOR: dict[ExprKind, Callable[[ExprNode], ExprMetadata]] = { + ExprKind.AGGREGATION: ExprMetadata.from_aggregation, + ExprKind.ALL: ExprMetadata.from_selector_multi_unnamed, + ExprKind.ELEMENTWISE: ExprMetadata.from_elementwise, + ExprKind.EXCLUDE: ExprMetadata.from_selector_multi_unnamed, + ExprKind.SERIES: ExprMetadata.from_series, + ExprKind.COL: ExprMetadata.from_col, + ExprKind.LITERAL: ExprMetadata.from_literal, + ExprKind.NTH: ExprMetadata.from_nth, + ExprKind.SELECTOR: ExprMetadata.from_selector_multi_unnamed, +} + +KIND_TO_METADATA_UPDATER: dict[ExprKind, Callable[..., ExprMetadata]] = { + ExprKind.AGGREGATION: ExprMetadata.with_aggregation, + ExprKind.ELEMENTWISE: ExprMetadata.with_elementwise, + ExprKind.FILTRATION: ExprMetadata.with_filtration, + ExprKind.ORDERABLE_AGGREGATION: ExprMetadata.with_orderable_aggregation, + ExprKind.ORDERABLE_FILTRATION: ExprMetadata.with_orderable_filtration, + ExprKind.OVER: ExprMetadata.with_over, + ExprKind.ORDERABLE_WINDOW: ExprMetadata.with_orderable_window, + ExprKind.WINDOW: ExprMetadata.with_window, +} def combine_metadata( - *args: IntoExpr | object | None, - str_as_lit: bool, - allow_multi_output: bool, + *compliant_exprs: CompliantExprAny, to_single_output: bool, + current_node: ExprNode, + prev: ExprMetadata | None, ) -> ExprMetadata: """Combine metadata from `args`. Arguments: - args: Arguments, maybe expressions, literals, or Series. - str_as_lit: Whether to interpret strings as literals or as column names. - allow_multi_output: Whether to allow multi-output inputs. + compliant_exprs: Expression arguments. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). + current_node: The current node being added. + prev: ExprMetadata of previous node. """ n_filtrations = 0 result_expansion_kind = ExpansionKind.SINGLE @@ -519,97 +741,151 @@ def combine_metadata( # result is literal if all inputs are literal result_is_literal = True - for i, arg in enumerate(args): - if (isinstance(arg, str) and not str_as_lit) or is_series(arg): - result_preserves_length = True - result_is_scalar_like = False - result_is_literal = False - elif is_expr(arg): - metadata = arg._metadata - if metadata.expansion_kind.is_multi_output(): - expansion_kind = metadata.expansion_kind - if i > 0 and not allow_multi_output: - # Left-most argument is always allowed to be multi-output. - msg = ( - "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) " - "are not supported in this context." - ) - raise MultiOutputExpressionError(msg) - if not to_single_output: - result_expansion_kind = ( - result_expansion_kind & expansion_kind - if i > 0 - else expansion_kind - ) - - result_has_windows |= metadata.has_windows - result_n_orderable_ops += metadata.n_orderable_ops - result_preserves_length |= metadata.preserves_length - result_is_elementwise &= metadata.is_elementwise - result_is_scalar_like &= metadata.is_scalar_like - result_is_literal &= metadata.is_literal - n_filtrations += int(metadata.is_filtration) - + for i, ce in enumerate(compliant_exprs): + metadata = ce._metadata + assert metadata is not None # noqa: S101 + if metadata.expansion_kind.is_multi_output(): + expansion_kind = metadata.expansion_kind + if not to_single_output: + result_expansion_kind = ( + result_expansion_kind & expansion_kind if i > 0 else expansion_kind + ) + + result_has_windows |= metadata.has_windows + result_n_orderable_ops += metadata.n_orderable_ops + result_preserves_length |= metadata.preserves_length + result_is_elementwise &= metadata.is_elementwise + result_is_scalar_like &= metadata.is_scalar_like + result_is_literal &= metadata.is_literal + n_filtrations += int(metadata.is_filtration) if n_filtrations > 1: msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation" raise InvalidOperationError(msg) if result_preserves_length and n_filtrations: msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations" raise InvalidOperationError(msg) - return ExprMetadata( result_expansion_kind, - # n-ary operations align positionally, and so the last node is elementwise. - ExprKind.ELEMENTWISE, has_windows=result_has_windows, n_orderable_ops=result_n_orderable_ops, preserves_length=result_preserves_length, is_elementwise=result_is_elementwise, is_scalar_like=result_is_scalar_like, is_literal=result_is_literal, + current_node=current_node, + prev=prev, ) -def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None: +def check_expressions_preserve_length( + *args: CompliantExprAny, function_name: str +) -> None: # Raise if any argument in `args` isn't length-preserving. # For Series input, we don't raise (yet), we let such checks happen later, # as this function works lazily and so can't evaluate lengths. - from narwhals.series import Series - if not all( - (is_expr(x) and x._metadata.preserves_length) or isinstance(x, (str, Series)) - for x in args - ): + if not all(x._metadata.preserves_length for x in args): msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'." raise InvalidOperationError(msg) -def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool: - # Raise if any argument in `args` isn't an aggregation or literal. - # For Series input, we don't raise (yet), we let such checks happen later, - # as this function works lazily and so can't evaluate lengths. - exprs = chain(args, kwargs.values()) - return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs) +def _parse_into_expr( + arg: IntoExpr | NonNestedLiteral | _1DArray, + *, + str_as_lit: bool = False, + backend: Any = None, + allow_literal: bool = True, +) -> Expr: + from narwhals.functions import col, lit, new_series + + if isinstance(arg, str) and not str_as_lit: + return col(arg) + if is_numpy_array_1d(arg): + return new_series("", arg, backend=backend)._to_expr() + if is_series(arg): + return arg._to_expr() + if is_expr(arg): + return arg + if not allow_literal: + raise InvalidIntoExprError.from_invalid_type(type(arg)) + return lit(arg) + + +def evaluate_into_exprs( + *exprs: IntoExpr | NonNestedLiteral | _1DArray, + ns: CompliantNamespaceAny, + str_as_lit: bool, + allow_multi_output: bool, +) -> Iterator[CompliantExprAny]: + for expr in exprs: + ret = _parse_into_expr( + expr, str_as_lit=str_as_lit, backend=ns._implementation + )._to_compliant_expr(ns) + if not allow_multi_output and ret._metadata.expansion_kind.is_multi_output(): + msg = "Multi-output expressions are not allowed in this context." + raise MultiOutputExpressionError(msg) + yield ret + + +def maybe_broadcast_ces(*compliant_exprs: CompliantExprAny) -> list[CompliantExprAny]: + broadcast = any(not is_scalar_like(ce) for ce in compliant_exprs) + results: list[CompliantExprAny] = [] + for compliant_expr in compliant_exprs: + if broadcast and is_scalar_like(compliant_expr): + _compliant_expr: CompliantExprAny = compliant_expr.broadcast() + # Make sure to preserve metadata. + _compliant_expr._opt_metadata = compliant_expr._metadata + results.append(_compliant_expr) + else: + results.append(compliant_expr) + return results + + +def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantExprAny: + if node.name in {"col", "exclude"}: + # There's too much potential for Sequence[str] vs str bugs, so we pass down + # `names` positionally rather than as a sequence of strings. + ce = getattr(ns, node.name)(*node.kwargs["names"]) + ces = [] + else: + if "." in node.name: + module, method = node.name.split(".") + func = getattr(getattr(ns, module), method) + else: + func = getattr(ns, node.name) + ces = maybe_broadcast_ces( + *evaluate_into_exprs( + *node.exprs, + ns=ns, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + ) + ) + ce = cast("CompliantExprAny", func(*ces, **node.kwargs)) + md = ExprMetadata.from_node(node, *ces) + ce._opt_metadata = md + return ce -def apply_n_ary_operation( - plx: CompliantNamespaceAny, - n_ary_function: Callable[..., CompliantExprAny], - *comparands: IntoExpr | NonNestedLiteral | _1DArray, - str_as_lit: bool, +def evaluate_node( + compliant_expr: CompliantExprAny, node: ExprNode, ns: CompliantNamespaceAny ) -> CompliantExprAny: - parse = plx.parse_into_expr - compliant_exprs = (parse(into, str_as_lit=str_as_lit) for into in comparands) - kinds = [ - ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit) - for comparand in comparands - ] - - broadcast = any(not kind.is_scalar_like for kind in kinds) - compliant_exprs = ( - compliant_expr.broadcast(kind) - if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind) - else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + md: ExprMetadata = compliant_expr._metadata + compliant_expr, *compliant_expr_args = maybe_broadcast_ces( + compliant_expr, + *evaluate_into_exprs( + *node.exprs, + ns=ns, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + ), ) - return n_ary_function(*compliant_exprs) + md = md.with_node(node, compliant_expr, *compliant_expr_args) + if "." in node.name: + accessor, method = node.name.split(".") + func = getattr(getattr(compliant_expr, accessor), method) + else: + func = getattr(compliant_expr, node.name) + ret = cast("CompliantExprAny", func(*compliant_expr_args, **node.kwargs)) + ret._opt_metadata = md + return ret diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 72edae8464..f5470700a1 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -423,12 +423,5 @@ def sink_parquet(self, file: str | Path | BytesIO) -> None: raise NotImplementedError(msg) self.native.to_parquet(file) - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) - # Intentionally not implemented, as Ibis does its own expression rewriting. _evaluate_window_expr = not_implemented() diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 5aa5f42396..de322295eb 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast import ibis @@ -41,7 +41,6 @@ EvalSeries, WindowFunction, ) - from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._ibis.dataframe import IbisLazyFrame from narwhals._ibis.namespace import IbisNamespace from narwhals._utils import _LimitedContext @@ -69,7 +68,6 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._metadata: ExprMetadata | None = None self._window_function: IbisWindowFunction | None = window_function @property @@ -132,7 +130,7 @@ def __narwhals_namespace__(self) -> IbisNamespace: # pragma: no cover return IbisNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Ibis does its own broadcasting. return self @@ -184,11 +182,11 @@ def func(df: IbisLazyFrame) -> Sequence[ir.Column]: version=context._version, ) - def _with_binary(self, op: Callable[..., ir.Value], other: Self | Any) -> Self: + def _with_binary(self, op: Callable[..., ir.Value], other: Self) -> Self: return self._with_callable(op, other=other) def _with_elementwise( - self, op: Callable[..., ir.Value], /, **expressifiable_args: Self | Any + self, op: Callable[..., ir.Value], /, **expressifiable_args: Self ) -> Self: return self._with_callable(op, **expressifiable_args) @@ -208,18 +206,6 @@ def quantile( raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.quantile(quantile)) - def clip(self, lower_bound: Any, upper_bound: Any) -> Self: - def _clip( - expr: ir.NumericValue, lower: Any | None = None, upper: Any | None = None - ) -> ir.NumericValue: - return expr.clip(lower=lower, upper=upper) - - if lower_bound is None: - return self._with_callable(_clip, upper=upper_bound) - if upper_bound is None: - return self._with_callable(_clip, lower=lower_bound) - return self._with_callable(_clip, lower=lower_bound, upper=upper_bound) - def n_unique(self) -> Self: return self._with_callable( lambda expr: expr.nunique() + expr.isnull().any().cast("int8") @@ -240,7 +226,7 @@ def null_count(self) -> Self: return self._with_callable(lambda expr: expr.isnull().sum()) def is_nan(self) -> Self: - def func(expr: ir.FloatingValue | Any) -> ir.Value: + def func(expr: ir.FloatingValue) -> ir.Value: otherwise = expr.isnan() if is_floating(expr.type()) else False return ibis.ifelse(expr.isnull(), None, otherwise) @@ -252,7 +238,7 @@ def is_finite(self) -> Self: def is_in(self, other: Sequence[Any]) -> Self: return self._with_callable(lambda expr: expr.isin(other)) - def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self: + def fill_null(self, value: Self | None, strategy: Any, limit: int | None) -> Self: # Ibis doesn't yet allow ignoring nulls in first/last with window functions, which makes forward/backward # strategies inconsistent when there are nulls present: https://github.com/ibis-project/ibis/issues/9539 if strategy is not None: @@ -265,6 +251,7 @@ def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value: return expr.fill_null(value) + assert value is not None # noqa: S101 return self._with_callable(_fill_null, value=value) def cast(self, dtype: IntoDType) -> Self: diff --git a/narwhals/_ibis/expr_str.py b/narwhals/_ibis/expr_str.py index 189db5823b..26a1f0c0ed 100644 --- a/narwhals/_ibis/expr_str.py +++ b/narwhals/_ibis/expr_str.py @@ -40,12 +40,8 @@ def fn(expr: ir.StringColumn) -> ir.StringValue: return fn - def replace_all( - self, pattern: str, value: str | IbisExpr, *, literal: bool - ) -> IbisExpr: + def replace_all(self, value: IbisExpr, pattern: str, *, literal: bool) -> IbisExpr: fn = self._replace_all_literal if literal else self._replace_all - if isinstance(value, str): - return self.compliant._with_callable(fn(pattern, value)) return self.compliant._with_elementwise( lambda expr, value: fn(pattern, value)(expr), value=value ) diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 2c73560894..d3116edd4f 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -17,8 +17,7 @@ from narwhals._ibis.selectors import IbisSelectorNamespace from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import Implementation, requires +from narwhals._utils import Implementation if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -110,10 +109,6 @@ def func(cols: Iterable[ir.Value]) -> ir.Value: return self._expr._from_elementwise_horizontal_op(func, *exprs) - @requires.backend_version((10, 0)) - def when(self, predicate: IbisExpr) -> IbisWhen: - return IbisWhen.from_expr(predicate, context=self) - def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr: def func(_df: IbisLazyFrame) -> Sequence[ir.Value]: ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None @@ -136,27 +131,3 @@ def func(_df: IbisLazyFrame) -> list[ir.Value]: alias_output_names=None, version=self._version, ) - - -class IbisWhen(SQLWhen["IbisLazyFrame", "ir.Value", IbisExpr]): - lit = lit - - @property - def _then(self) -> type[IbisThen]: - return IbisThen - - def __call__(self, df: IbisLazyFrame) -> Sequence[ir.Value]: - is_expr = self._condition._is_expr - condition = df._evaluate_expr(self._condition) - then_ = self._then_value - then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_) - other_ = self._otherwise_value - if other_ is None: - result = ibis.cases((condition, then)) - else: - otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_) - result = ibis.cases((condition, then), else_=otherwise) - return [result] - - -class IbisThen(SQLThen["IbisLazyFrame", "ir.Value", IbisExpr], IbisExpr): ... diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 290181ab78..13e2351a2c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -418,7 +418,7 @@ def simple_select(self, *column_names: str) -> Self: ) def select(self, *exprs: PandasLikeExpr) -> Self: - new_series = self._evaluate_into_exprs(*exprs) + new_series = self._evaluate_exprs(*exprs) if not new_series: # return empty dataframe, like Polars does return self._with_native(type(self.native)(), validate_column_names=False) @@ -461,15 +461,14 @@ def row(self, index: int) -> tuple[Any, ...]: return tuple(x for x in self.native.iloc[index]) def filter(self, predicate: PandasLikeExpr) -> Self: - # `[0]` is safe as the predicate's expression only returns a single column - mask = self._evaluate_into_exprs(predicate)[0] + mask = self._evaluate_single_output_expr(predicate) mask_native = self._extract_comparand(mask) return self._with_native( self.native.loc[mask_native], validate_column_names=False ) def with_columns(self, *exprs: PandasLikeExpr) -> Self: - columns = self._evaluate_into_exprs(*exprs) + columns = self._evaluate_exprs(*exprs) if not columns and len(self) == 0: return self name_columns: dict[str, PandasLikeSeries] = {s.name: s for s in columns} diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index d5332397c8..91bf159493 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from narwhals._compliant import EagerExpr from narwhals._expression_parsing import evaluate_output_names_and_aliases @@ -19,9 +19,7 @@ EvalNames, EvalSeries, NarwhalsAggregation, - ScalarKwargs, ) - from narwhals._expression_parsing import ExprMetadata from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace from narwhals._utils import Implementation, Version, _LimitedContext @@ -50,7 +48,7 @@ def window_kwargs_to_pandas_equivalent( # noqa: C901 - function_name: str, kwargs: ScalarKwargs + function_name: str, kwargs: dict[str, Any] ) -> dict[str, PythonLiteral]: if function_name == "shift": assert "n" in kwargs # noqa: S101 @@ -124,23 +122,16 @@ def __init__( self, call: EvalSeries[PandasLikeDataFrame, PandasLikeSeries], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[PandasLikeDataFrame], alias_output_names: AliasNames | None, implementation: Implementation, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._implementation = implementation self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None def __narwhals_namespace__(self) -> PandasLikeNamespace: from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -154,7 +145,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: try: @@ -173,8 +163,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, implementation=context._implementation, @@ -192,8 +180,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, implementation=context._implementation, @@ -213,20 +199,19 @@ def ewm_mean( ) -> Self: return self._reuse_series( "ewm_mean", - scalar_kwargs={ - "com": com, - "span": span, - "half_life": half_life, - "alpha": alpha, - "adjust": adjust, - "min_samples": min_samples, - "ignore_nulls": ignore_nulls, - }, + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_samples=min_samples, + ignore_nulls=ignore_nulls, ) def over( # noqa: C901, PLR0915 self, partition_by: Sequence[str], order_by: Sequence[str] ) -> Self: + op_nodes_reversed = list(self._metadata.op_nodes_reversed()) if not partition_by: # e.g. `nw.col('a').cum_sum().order_by(key)` # We can always easily support this as it doesn't require grouping. @@ -253,7 +238,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: for s in results: s._scatter_in_place(sorting_indices, s) return results - elif not self._is_elementary(): + elif len(op_nodes_reversed) > 2: msg = ( "Only elementary expressions are supported for `.over` in pandas-like backends.\n\n" "Please see: " @@ -261,9 +246,14 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: - function_name = PandasLikeGroupBy._leaf_name(self) + assert op_nodes_reversed # noqa: S101 + leaf_node = op_nodes_reversed[0] + function_name = leaf_node.name + pandas_agg = PandasLikeGroupBy._REMAP_AGGS.get( + cast("NarwhalsAggregation", function_name) + ) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( - function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name) + function_name, pandas_agg ) if pandas_function_name is None: msg = ( @@ -272,8 +262,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}." ) raise NotImplementedError(msg) + scalar_kwargs = leaf_node.kwargs pandas_kwargs = window_kwargs_to_pandas_equivalent( - function_name, self._scalar_kwargs + function_name, scalar_kwargs ) def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, PLR0912, PLR0914, PLR0915 @@ -284,10 +275,10 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, df = df.with_columns(~plx.col(*output_names).is_null()) if function_name.startswith("cum_"): - assert "reverse" in self._scalar_kwargs # noqa: S101 - reverse = self._scalar_kwargs["reverse"] + assert "reverse" in scalar_kwargs # noqa: S101 + reverse = scalar_kwargs["reverse"] else: - assert "reverse" not in self._scalar_kwargs # noqa: S101 + assert "reverse" not in scalar_kwargs # noqa: S101 reverse = False if order_by: @@ -306,9 +297,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, if function_name.startswith("rolling"): rolling = grouped[list(output_names)].rolling(**pandas_kwargs) if pandas_function_name in {"std", "var"}: - assert "ddof" in self._scalar_kwargs # noqa: S101 + assert "ddof" in scalar_kwargs # noqa: S101 res_native = getattr(rolling, pandas_function_name)( - ddof=self._scalar_kwargs["ddof"] + ddof=scalar_kwargs["ddof"] ) else: res_native = getattr(rolling, pandas_function_name)() @@ -325,13 +316,13 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, assert pandas_function_name is not None # help mypy # noqa: S101 res_native = getattr(ewm, pandas_function_name)() elif function_name == "fill_null": - assert "strategy" in self._scalar_kwargs # noqa: S101 - assert "limit" in self._scalar_kwargs # noqa: S101 + assert "strategy" in scalar_kwargs # noqa: S101 + assert "limit" in scalar_kwargs # noqa: S101 df_grouped = grouped[list(output_names)] - if self._scalar_kwargs["strategy"] == "forward": - res_native = df_grouped.ffill(limit=self._scalar_kwargs["limit"]) - elif self._scalar_kwargs["strategy"] == "backward": - res_native = df_grouped.bfill(limit=self._scalar_kwargs["limit"]) + if scalar_kwargs["strategy"] == "forward": + res_native = df_grouped.ffill(limit=scalar_kwargs["limit"]) + elif scalar_kwargs["strategy"] == "backward": + res_native = df_grouped.bfill(limit=scalar_kwargs["limit"]) else: # pragma: no cover # This is deprecated in pandas. Indeed, `nw.col('a').fill_null(3).over('b')` # does not seem very useful, and DuckDB doesn't support it either. @@ -374,8 +365,6 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index d3a84ae72f..def77bd15b 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -122,7 +122,8 @@ def _getitem_aggs( ) elif self.is_mode(): compliant = group_by.compliant - if (keep := self.kwargs.get("keep")) != "any": # pragma: no cover + node_kwargs = group_by._kwargs(self.expr) + if (keep := node_kwargs.get("keep")) != "any": # pragma: no cover msg = ( f"`Expr.mode(keep='{keep}')` is not implemented in group by context for " f"backend {compliant._implementation}\n\n" @@ -132,7 +133,7 @@ def _getitem_aggs( cols = list(names) native = compliant.native - keys, kwargs = group_by._keys, group_by._kwargs + keys, kwargs = group_by._keys, group_by._group_by_kwargs # Implementation based on the following suggestion: # https://github.com/pandas-dev/pandas/issues/19254#issuecomment-778661578 @@ -175,11 +176,7 @@ def is_mode(self) -> bool: def is_top_level_function(self) -> bool: # e.g. `nw.len()`. - return self.expr._depth == 0 - - @property - def kwargs(self) -> ScalarKwargs: - return self.expr._scalar_kwargs + return len(list(self.expr._metadata.op_nodes_reversed())) == 1 @property def leaf_name(self) -> NarwhalsAggregation | Any: @@ -191,9 +188,10 @@ def leaf_name(self) -> NarwhalsAggregation | Any: def native_agg(self) -> _NativeAgg: """Return a partial `DataFrameGroupBy` method, missing only `self`.""" native_name = PandasLikeGroupBy._remap_expr_name(self.leaf_name) + last_node = next(self.expr._metadata.op_nodes_reversed()) if self.leaf_name in _REMAP_ORDERED_INDEX: return methodcaller("nth", n=_REMAP_ORDERED_INDEX[self.leaf_name]) - return _native_agg(native_name, **self.kwargs) + return _native_agg(native_name, **last_node.kwargs) class PandasLikeGroupBy( @@ -226,7 +224,7 @@ class PandasLikeGroupBy( _output_key_names: list[str] """Stores the **original** version of group keys.""" - _kwargs: Mapping[str, bool] + _group_by_kwargs: Mapping[str, bool] """Stores keyword arguments for `DataFrame.groupby` other than `by`.""" @property @@ -254,13 +252,15 @@ def __init__( if set(native.index.names).intersection(self.compliant.columns): native = native.reset_index(drop=True) - self._kwargs = { + self._group_by_kwargs = { "sort": False, "as_index": True, "dropna": drop_null_keys, "observed": True, } - self._grouped: NativeGroupBy = native.groupby(self._keys.copy(), **self._kwargs) + self._grouped: NativeGroupBy = native.groupby( + self._keys.copy(), **self._group_by_kwargs + ) def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: all_aggs_are_simple = True diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index fff20e292b..96b7a22290 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -6,7 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._compliant import EagerNamespace from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, @@ -24,7 +24,6 @@ from typing_extensions import TypeAlias - from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Implementation, Version from narwhals.typing import IntoDType, NonNestedLiteral @@ -79,8 +78,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -100,8 +97,6 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: return PandasLikeExpr( lambda df: [_lit_pandas_series(df)], - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, implementation=self._implementation, @@ -115,8 +110,6 @@ def len(self) -> PandasLikeExpr: [len(df._native_frame)], name="len", index=[0], context=self ) ], - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, implementation=self._implementation, @@ -132,8 +125,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -164,8 +155,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -196,8 +185,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -212,8 +199,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -234,8 +219,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -256,8 +239,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -308,9 +289,6 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) - def when(self, predicate: PandasLikeExpr) -> PandasWhen[NativeSeriesT]: - return PandasWhen[NativeSeriesT].from_expr(predicate, context=self) - def concat_str( self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool ) -> PandasLikeExpr: @@ -352,13 +330,20 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, ) + def _if_then_else( + self, + when: NativeSeriesT, + then: NativeSeriesT, + otherwise: NativeSeriesT | None = None, + ) -> NativeSeriesT: + where: Incomplete = then.where + return where(when) if otherwise is None else where(when, otherwise) + class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]): @overload @@ -397,31 +382,3 @@ def __call__( axis: Axis, copy: bool | None = None, ) -> NativeDataFrameT | NativeSeriesT: ... - - -class PandasWhen( - EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, NativeSeriesT] -): - @property - # Signature of "_then" incompatible with supertype "CompliantWhen" - # ArrowWhen seems to follow the same pattern, but no mypy complaint there? - def _then(self) -> type[PandasThen]: # type: ignore[override] - return PandasThen - - def _if_then_else( - self, - when: NativeSeriesT, - then: NativeSeriesT, - otherwise: NativeSeriesT | NonNestedLiteral, - ) -> NativeSeriesT: - where: Incomplete = then.where - return where(when) if otherwise is None else where(when, otherwise) - - -class PandasThen( - CompliantThen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, PandasWhen], - PandasLikeExpr, -): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 7e68108ed3..b2462561cf 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -6,7 +6,6 @@ from narwhals._pandas_like.expr import PandasLikeExpr if TYPE_CHECKING: - from narwhals._compliant.typing import ScalarKwargs from narwhals._pandas_like.dataframe import PandasLikeDataFrame # noqa: F401 from narwhals._pandas_like.series import PandasLikeSeries # noqa: F401 @@ -22,15 +21,9 @@ def _selector(self) -> type[PandasSelector]: class PandasSelector( # type: ignore[misc] CompliantSelector["PandasLikeDataFrame", "PandasLikeSeries"], PandasLikeExpr ): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> PandasLikeExpr: return PandasLikeExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 40921715f0..aab6fbdfa0 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -50,12 +50,10 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, PythonLiteral, RankMethod, RollingInterpolationMethod, SizedMultiIndexSelector, - TemporalLiteral, _1DArray, _SliceIndex, ) @@ -876,21 +874,9 @@ def to_dummies(self, *, separator: str, drop_first: bool) -> PandasLikeDataFrame def gather_every(self, n: int, offset: int) -> Self: return self._with_native(self.native.iloc[offset::n]) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: - _, lower = ( - align_and_extract_native(self, lower_bound) - if lower_bound is not None - else (None, None) - ) - _, upper = ( - align_and_extract_native(self, upper_bound) - if upper_bound is not None - else (None, None) - ) + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: + _, lower = align_and_extract_native(self, lower_bound) + _, upper = align_and_extract_native(self, upper_bound) impl = self._implementation kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} result = self.native @@ -908,6 +894,36 @@ def clip( return self._with_native(result.clip(lower, upper, **kwargs)) + def clip_lower(self, lower_bound: Self) -> Self: + _, lower = align_and_extract_native(self, lower_bound) + impl = self._implementation + kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} + result = self.native + + if not impl.is_pandas() and self._is_native(lower): # pragma: no cover + # Workaround for both cudf and modin when clipping with a series + # * cudf: https://github.com/rapidsai/cudf/issues/17682 + # * modin: https://github.com/modin-project/modin/issues/7415 + result = result.where(result >= lower, lower) + lower = None + + return self._with_native(result.clip(lower, **kwargs)) + + def clip_upper(self, upper_bound: Self) -> Self: + _, upper = align_and_extract_native(self, upper_bound) + impl = self._implementation + kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} + result = self.native + + if not impl.is_pandas() and self._is_native(upper): # pragma: no cover + # Workaround for both cudf and modin when clipping with a series + # * cudf: https://github.com/rapidsai/cudf/issues/17682 + # * modin: https://github.com/modin-project/modin/issues/7415 + result = result.where(result <= upper, upper) + upper = None + + return self._with_native(result.clip(upper=upper, **kwargs)) + def to_arrow(self) -> pa.Array[Any]: if self._implementation is Implementation.CUDF: return self.native.to_arrow() diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index 7fb820598a..d73c46dc20 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Any from narwhals._compliant.any_namespace import StringNamespace -from narwhals._pandas_like.utils import PandasLikeSeriesNamespace, is_dtype_pyarrow +from narwhals._pandas_like.utils import ( + PandasLikeSeriesNamespace, + align_and_extract_native, + is_dtype_pyarrow, +) if TYPE_CHECKING: from narwhals._pandas_like.series import PandasLikeSeries @@ -16,21 +20,21 @@ def len_chars(self) -> PandasLikeSeries: return self.with_native(self.native.str.len()) def replace( - self, pattern: str, value: str, *, literal: bool, n: int + self, value: PandasLikeSeries, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: - try: - series = self.native.str.replace( - pat=pattern, repl=value, n=n, regex=not literal - ) - except TypeError as e: - if not isinstance(value, str): - msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + _, value_native = align_and_extract_native(self.compliant, value) + if not isinstance(value_native, str): + msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values" + raise TypeError(msg) + series = self.native.str.replace( + pat=pattern, repl=value_native, n=n, regex=not literal + ) return self.with_native(series) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> PandasLikeSeries: - return self.replace(pattern, value, literal=literal, n=-1) + def replace_all( + self, value: PandasLikeSeries, pattern: str, *, literal: bool + ) -> PandasLikeSeries: + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> PandasLikeSeries: return self.with_native(self.native.str.strip(characters)) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 675eabe29a..05c54c0c85 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -170,6 +170,7 @@ def align_and_extract_native( if isinstance(rhs, list): msg = "Expected Series or scalar, got list." raise TypeError(msg) + # `rhs` must be scalar, so just leave it as-is return lhs.native, rhs diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 1e147a829d..5857b92350 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal +from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast import polars as pl @@ -24,11 +24,12 @@ from typing_extensions import Self from narwhals._compliant.typing import Accessor - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._expression_parsing import ExprMetadata from narwhals._polars.dataframe import Method from narwhals._polars.namespace import PolarsNamespace + from narwhals._polars.series import PolarsSeries from narwhals._utils import Version - from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral + from narwhals.typing import IntoDType, ModeKeepStrategy class PolarsExpr: @@ -36,11 +37,14 @@ class PolarsExpr: _implementation: Implementation = Implementation.POLARS _version: Version _native_expr: pl.Expr - _metadata: ExprMetadata | None = None _evaluate_output_names: Any _alias_output_names: Any __call__: Any + @classmethod + def _from_series(cls, series: PolarsSeries) -> Self: + return cls(series.native, version=series._version) # type: ignore[arg-type] + # CompliantExpr + builtin descriptor # TODO @dangotbanned: Remove in #2713 @classmethod @@ -77,10 +81,15 @@ def __repr__(self) -> str: # pragma: no cover def _with_native(self, expr: pl.Expr) -> Self: return self.__class__(expr, self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Let Polars do its thing. return self + @property + def _metadata(self) -> ExprMetadata: + assert self._opt_metadata is not None # noqa: S101 + return cast("ExprMetadata", self._opt_metadata) + def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: pos, kwds = extract_args_kwargs(args, kwargs) @@ -96,6 +105,14 @@ def cast(self, dtype: IntoDType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version) return self._with_native(self.native.cast(dtype_pl)) + def clip_lower(self, lower_bound: PolarsExpr) -> Self: + lower_native = extract_native(lower_bound) + return self._with_native(self.native.clip(lower_native)) + + def clip_upper(self, upper_bound: PolarsExpr) -> Self: + upper_native = extract_native(upper_bound) + return self._with_native(self.native.clip(None, upper_native)) + def ewm_mean( self, *, @@ -205,11 +222,11 @@ def replace_strict( native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl) return self._with_native(native) - def __eq__(self, other: object) -> Self: # type: ignore[override] - return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator] + def __eq__(self, other: PolarsExpr) -> Self: # type: ignore[override] + return self._with_native(self.native.__eq__(extract_native(other))) - def __ne__(self, other: object) -> Self: # type: ignore[override] - return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator] + def __ne__(self, other: PolarsExpr) -> Self: # type: ignore[override] + return self._with_native(self.native.__ne__(extract_native(other))) def __ge__(self, other: Any) -> Self: return self._with_native(self.native.__ge__(extract_native(other))) @@ -223,11 +240,11 @@ def __le__(self, other: Any) -> Self: def __lt__(self, other: Any) -> Self: return self._with_native(self.native.__lt__(extract_native(other))) - def __and__(self, other: PolarsExpr | bool | Any) -> Self: - return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator] + def __and__(self, other: PolarsExpr) -> Self: + return self._with_native(self.native.__and__(extract_native(other))) - def __or__(self, other: PolarsExpr | bool | Any) -> Self: - return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator] + def __or__(self, other: PolarsExpr) -> Self: + return self._with_native(self.native.__or__(extract_native(other))) def __add__(self, other: Any) -> Self: return self._with_native(self.native.__add__(extract_native(other))) @@ -264,46 +281,6 @@ def __invert__(self) -> Self: def cum_count(self, *, reverse: bool) -> Self: return self._with_native(self.native.cum_count(reverse=reverse)) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - left = self.native - right = other.native if isinstance(other, PolarsExpr) else pl.lit(other) - - if self._backend_version < (1, 32, 0): - lower_bound = right.abs() - tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol) - - # Values are close if abs_diff <= tolerance, and both finite - abs_diff = (left - right).abs() - all_ = pl.all_horizontal - is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite()) - - # Handle infinity cases: infinities are "close" only if they have the same sign - is_same_inf = all_( - left.is_infinite(), right.is_infinite(), (left.sign() == right.sign()) - ) - - # Handle nan cases: - # * nans_equals = True => if both values are NaN, then True - # * nans_equals = False => if any value is NaN, then False - left_is_nan, right_is_nan = left.is_nan(), right.is_nan() - either_nan = left_is_nan | right_is_nan - result = (is_close | is_same_inf) & either_nan.not_() - - if nans_equal: - result = result | (left_is_nan & right_is_nan) - else: - result = left.is_close( - right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self._with_native(result) - def mode(self, *, keep: ModeKeepStrategy) -> Self: result = self.native.mode() return self._with_native(result.first() if keep == "any" else result) @@ -341,7 +318,6 @@ def struct(self) -> PolarsExprStructNamespace: arg_min: Method[Self] arg_true: Method[Self] ceil: Method[Self] - clip: Method[Self] count: Method[Self] cum_max: Method[Self] cum_min: Method[Self] @@ -451,6 +427,22 @@ def zfill(self, width: int) -> PolarsExpr: return self.compliant._with_native(native_result) + def replace( + self, value: PolarsExpr, pattern: str, *, literal: bool, n: int + ) -> PolarsExpr: + value_native = extract_native(value) + return self.compliant._with_native( + self.native.str.replace(pattern, value_native, literal=literal, n=n) + ) + + def replace_all( + self, value: PolarsExpr, pattern: str, *, literal: bool + ) -> PolarsExpr: + value_native = extract_native(value) + return self.compliant._with_native( + self.native.str.replace_all(pattern, value_native, literal=literal) + ) + class PolarsExprCatNamespace( PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr] diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 8f257b36bd..ac8da364be 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -5,7 +5,6 @@ import polars as pl -from narwhals._expression_parsing import is_expr, is_series from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype @@ -19,21 +18,11 @@ from typing_extensions import TypeIs - from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen + from narwhals._compliant import CompliantSelectorNamespace from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext - from narwhals.expr import Expr - from narwhals.series import Series - from narwhals.typing import ( - Into1DArray, - IntoDType, - IntoSchema, - NonNestedLiteral, - TimeUnit, - _1DArray, - _2DArray, - ) + from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, _2DArray class PolarsNamespace: @@ -45,8 +34,6 @@ class PolarsNamespace: min_horizontal: Method[PolarsExpr] max_horizontal: Method[PolarsExpr] - when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] - _implementation: Implementation = Implementation.POLARS _version: Version @@ -84,25 +71,6 @@ def _expr(self) -> type[PolarsExpr]: def _series(self) -> type[PolarsSeries]: return PolarsSeries - def parse_into_expr( - self, - data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray, - /, - *, - str_as_lit: bool, - ) -> PolarsExpr | None: - if data is None: - # NOTE: To avoid `pl.lit(None)` failing this `None` check - # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872 - return data - if is_expr(data): - expr = data._to_compliant_expr(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - if isinstance(data, str) and not str_as_lit: - return self.col(data) - return self.lit(data.to_native() if is_series(data) else data, None) - def is_native(self, obj: Any) -> TypeIs[pl.DataFrame | pl.LazyFrame | pl.Series]: return isinstance(obj, (pl.DataFrame, pl.LazyFrame, pl.Series)) @@ -145,7 +113,7 @@ def from_numpy( @requires.backend_version( (1, 0, 0), "Please use `col` for columns selection instead." ) - def nth(self, *indices: int) -> PolarsExpr: + def nth(self, indices: Sequence[int]) -> PolarsExpr: return self._expr(pl.nth(*indices), version=self._version) def len(self) -> PolarsExpr: @@ -230,6 +198,22 @@ def concat_str( version=self._version, ) + def when_then( + self, when: PolarsExpr, then: PolarsExpr, otherwise: PolarsExpr | None = None + ) -> PolarsExpr: + if otherwise is None: + (when_native, then_native), _ = extract_args_kwargs((when, then), {}) + return self._expr( + pl.when(when_native).then(then_native), version=self._version + ) + (when_native, then_native, otherwise_native), _ = extract_args_kwargs( + (when, then, otherwise), {} + ) + return self._expr( + pl.when(when_native).then(then_native).otherwise(otherwise_native), + version=self._version, + ) + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) # 1. Others have lots of private stuff for code reuse # i. None of that is useful here diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 5d5a799d97..f95ca504f6 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -43,7 +43,6 @@ ModeKeepStrategy, MultiIndexSelector, NonNestedLiteral, - NumericLiteral, PythonLiteral, _1DArray, ) @@ -289,6 +288,19 @@ def cast(self, dtype: IntoDType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version) return self._with_native(self.native.cast(dtype_pl)) + def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self: + return self._with_native( + self.native.clip(extract_native(lower_bound), extract_native(upper_bound)) + ) + + def clip_lower(self, lower_bound: PolarsSeries) -> Self: + return self._with_native(self.native.clip(extract_native(lower_bound))) + + def clip_upper(self, upper_bound: PolarsSeries) -> Self: + return self._with_native( + self.native.clip(upper_bound=extract_native(upper_bound)) + ) + @requires.backend_version((1,)) def replace_strict( self, @@ -510,30 +522,6 @@ def __contains__(self, other: Any) -> bool: except Exception as e: # noqa: BLE001 raise catch_polars_exception(e) from None - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> PolarsSeries: - if self._backend_version < (1, 32, 0): - name = self.name - ns = self.__narwhals_namespace__() - other_expr = ( - ns.lit(other.native, None) if isinstance(other, PolarsSeries) else other - ) - expr = ns.col(name).is_close( - other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self.to_frame().select(expr).get_column(name) - other_series = other.native if isinstance(other, PolarsSeries) else other - result = self.native.is_close( - other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self._with_native(result) - def mode(self, *, keep: ModeKeepStrategy) -> Self: result = self.native.mode() return self._with_native(result.head(1) if keep == "any" else result) @@ -702,7 +690,6 @@ def struct(self) -> PolarsSeriesStructNamespace: arg_min: Method[int] arg_true: Method[Self] ceil: Method[Self] - clip: Method[Self] count: Method[int] cum_max: Method[Self] cum_min: Method[Self] @@ -800,6 +787,22 @@ def zfill(self, width: int) -> PolarsSeries: ns = self.__narwhals_namespace__() return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name) + def replace( + self, value: PolarsSeries, pattern: str, *, literal: bool, n: int + ) -> PolarsSeries: + value_native = extract_native(value) + return self.compliant._with_native( + self.native.str.replace(pattern, value_native, literal=literal, n=n) # type: ignore[arg-type] + ) + + def replace_all( + self, value: PolarsSeries, pattern: str, *, literal: bool + ) -> PolarsSeries: + value_native = extract_native(value) + return self.compliant._with_native( + self.native.str.replace_all(pattern, value_native, literal=literal) # type: ignore[arg-type] + ) + class PolarsSeriesCatNamespace( PolarsSeriesNamespace, PolarsCatNamespace[PolarsSeries, pl.Series] diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 4e4acb71b4..04e17f3ac0 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -604,10 +604,4 @@ def _from_compliant_dataframe( validate_backend_version=True, ) - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) join_asof = not_implemented() - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 5d584956d2..e7a8ba7287 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -3,7 +3,6 @@ import operator from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast -from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace from narwhals._spark_like.expr_list import SparkLikeExprListNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace @@ -41,7 +40,7 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals._utils import _LimitedContext - from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod + from narwhals.typing import FillNullStrategy, IntoDType, RankMethod NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"] SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column] @@ -64,7 +63,6 @@ def __init__( self._alias_output_names = alias_output_names self._version = version self._implementation = implementation - self._metadata: ExprMetadata | None = None self._window_function: SparkWindowFunction | None = window_function _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = { @@ -112,9 +110,7 @@ def _last(self, expr: Column, *order_by: str) -> Column: # pragma: no cover msg = "`last` is not supported for PySpark." raise NotImplementedError(msg) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.LITERAL: - return self + def broadcast(self) -> Self: return self.over([self._F.lit(1)], []) @property @@ -210,19 +206,19 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: implementation=context._implementation, ) - def __truediv__(self, other: SparkLikeExpr) -> Self: + def __truediv__(self, other: Self) -> Self: def _truediv(expr: Column, other: Column) -> Column: return true_divide(self._F, expr, other) return self._with_binary(_truediv, other) - def __rtruediv__(self, other: SparkLikeExpr) -> Self: + def __rtruediv__(self, other: Self) -> Self: def _rtruediv(expr: Column, other: Column) -> Column: return true_divide(self._F, other, expr) return self._with_binary(_rtruediv, other).alias("literal") - def __floordiv__(self, other: SparkLikeExpr) -> Self: + def __floordiv__(self, other: Self) -> Self: def _floordiv(expr: Column, other: Column) -> Column: F = self._F return F.when( @@ -231,7 +227,7 @@ def _floordiv(expr: Column, other: Column) -> Column: return self._with_binary(_floordiv, other) - def __rfloordiv__(self, other: SparkLikeExpr) -> Self: + def __rfloordiv__(self, other: Self) -> Self: def _rfloordiv(expr: Column, other: Column) -> Column: F = self._F return F.when( @@ -328,10 +324,7 @@ def _is_nan(expr: Column) -> Column: return self._with_elementwise(_is_nan) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: @@ -359,6 +352,7 @@ def _fill_with_strategy( def _fill_constant(expr: Column, value: Column) -> Column: return self._F.ifnull(expr, value) + assert value is not None # noqa: S101 return self._with_elementwise(_fill_constant, value=value) @property diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 0486ae95aa..c660b67298 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,13 +19,13 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen if TYPE_CHECKING: from collections.abc import Iterable from sqlframe.base.column import Column + from narwhals._compliant.window import WindowInputs from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401 from narwhals._utils import Implementation, Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral @@ -92,7 +92,7 @@ def _coalesce(self, *exprs: Column) -> Column: return self._F.coalesce(*exprs) def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr: - def _lit(df: SparkLikeLazyFrame) -> list[Column]: + def func(df: SparkLikeLazyFrame) -> list[Column]: column = df._F.lit(value) if dtype: native_dtype = narwhals_to_native_dtype( @@ -102,8 +102,14 @@ def _lit(df: SparkLikeLazyFrame) -> list[Column]: return [column] + def window_func( + df: SparkLikeLazyFrame, _window_inputs: WindowInputs[Column] + ) -> list[Column]: + return func(df) + return self._expr( - call=_lit, + func, + window_func, evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -190,17 +196,3 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: version=self._version, implementation=self._implementation, ) - - def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen: - return SparkLikeWhen.from_expr(predicate, context=self) - - -class SparkLikeWhen(SQLWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]): - @property - def _then(self) -> type[SparkLikeThen]: - return SparkLikeThen - - -class SparkLikeThen( - SQLThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr -): ... diff --git a/narwhals/_sql/dataframe.py b/narwhals/_sql/dataframe.py index 356a77373f..f4a03dc541 100644 --- a/narwhals/_sql/dataframe.py +++ b/narwhals/_sql/dataframe.py @@ -10,6 +10,7 @@ ) from narwhals._translate import ToNarwhalsT_co from narwhals._utils import check_columns_exist +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from collections.abc import Sequence @@ -34,12 +35,18 @@ def _evaluate_window_expr( window_inputs: WindowInputs[NativeExprT], ) -> NativeExprT: result = expr.window_function(self, window_inputs) - assert len(result) == 1 # debug assertion # noqa: S101 + if len(result) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] - def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: + def _evaluate_single_output_expr( + self, expr: SQLExpr[Self, NativeExprT], / + ) -> NativeExprT: result = expr(self) - assert len(result) == 1 # debug assertion # noqa: S101 + if len(result) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 1f5b49d9cd..670f4583ad 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -20,22 +20,15 @@ from narwhals._utils import Implementation, Version, extend_bool, not_implemented if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Sequence - from typing_extensions import Self, TypeIs + from typing_extensions import Self from narwhals._compliant.typing import AliasNames, WindowFunction - from narwhals._expression_parsing import ExprMetadata from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace from narwhals._sql.expr_str import SQLExprStringNamespace from narwhals._sql.namespace import SQLNamespace - from narwhals.typing import ( - ModeKeepStrategy, - NumericLiteral, - PythonLiteral, - RankMethod, - TemporalLiteral, - ) + from narwhals.typing import ModeKeepStrategy, PythonLiteral, RankMethod class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, NativeExprT]): @@ -44,7 +37,6 @@ class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, Nati _alias_output_names: AliasNames | None _version: Version _implementation: Implementation - _metadata: ExprMetadata | None _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None def __init__( @@ -66,14 +58,12 @@ def __narwhals_namespace__( ) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeExprT]: ... def _callable_to_eval_series( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> EvalSeries[SQLLazyFrameT, NativeExprT]: def func(df: SQLLazyFrameT) -> list[NativeExprT]: native_series_list = self(df) other_native_series = { - key: df._evaluate_expr(value) - if self._is_expr(value) - else self._lit(value) + key: df._evaluate_single_output_expr(value) for key, value in expressifiable_args.items() } return [ @@ -84,7 +74,7 @@ def func(df: SQLLazyFrameT) -> list[NativeExprT]: return func def _push_down_window_function( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> WindowFunction[SQLLazyFrameT, NativeExprT]: def window_f( df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] @@ -97,8 +87,6 @@ def window_f( native_series_list = self.window_function(df, window_inputs) other_native_series = { key: df._evaluate_window_expr(value, window_inputs) - if self._is_expr(value) - else self._lit(value) for key, value in expressifiable_args.items() } return [ @@ -125,7 +113,7 @@ def _with_callable( call: Callable[..., NativeExprT], window_func: WindowFunction[SQLLazyFrameT, NativeExprT] | None = None, /, - **expressifiable_args: Self | Any, + **expressifiable_args: Self, ) -> Self: return self.__class__( self._callable_to_eval_series(call, **expressifiable_args), @@ -137,7 +125,7 @@ def _with_callable( ) def _with_elementwise( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> Self: return self.__class__( self._callable_to_eval_series(call, **expressifiable_args), @@ -148,7 +136,7 @@ def _with_elementwise( implementation=self._implementation, ) - def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Self: + def _with_binary(self, op: Callable[..., NativeExprT], other: Self) -> Self: return self.__class__( self._callable_to_eval_series(op, other=other), self._push_down_window_function(op, other=other), @@ -183,8 +171,7 @@ def default_window_func( ) -> Sequence[NativeExprT]: assert not inputs.order_by # noqa: S101 return [ - self._window_expression(expr, inputs.partition_by, inputs.order_by) - for expr in self(df) + self._window_expression(expr, inputs.partition_by) for expr in self(df) ] return self._window_function or default_window_func @@ -303,10 +290,6 @@ def func( return func - @classmethod - def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]: - return hasattr(obj, "__narwhals_expr__") - @property def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() @@ -316,19 +299,16 @@ def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ... @classmethod def _from_elementwise_horizontal_op( - cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self + cls, func: Callable[[list[NativeExprT]], NativeExprT], *exprs: Self ) -> Self: def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]: - cols = (col for _expr in exprs for col in _expr(df)) - return [func(cols)] + return [func([e for expr in exprs for e in expr(df)])] def window_function( df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: - cols = ( - col for _expr in exprs for col in _expr.window_function(df, window_inputs) - ) - return [func(cols)] + lst = [e for expr in exprs for e in expr.window_function(df, window_inputs)] + return [func(lst)] context = exprs[0] return cls( @@ -351,7 +331,6 @@ def _is_multi_output_unnamed(self) -> bool: nw.all().sum(). """ - assert self._metadata is not None # noqa: S101 return self._metadata.expansion_kind.is_multi_unnamed() # Binary @@ -520,6 +499,7 @@ def f(expr: NativeExprT) -> NativeExprT: def window_f( df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: + assert not inputs.order_by # noqa: S101 return [ self._coalesce( self._window_expression( @@ -563,32 +543,30 @@ def window_f( def abs(self) -> Self: return self._with_elementwise(lambda expr: self._function("abs", expr)) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: - def _clip_lower(expr: NativeExprT, lower_bound: Any) -> NativeExprT: - return self._function("greatest", expr, lower_bound) - - def _clip_upper(expr: NativeExprT, upper_bound: Any) -> NativeExprT: - return self._function("least", expr, upper_bound) - - def _clip_both( - expr: NativeExprT, lower_bound: Any, upper_bound: Any + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: + def _clip( + expr: NativeExprT, lower_bound: NativeExprT, upper_bound: NativeExprT ) -> NativeExprT: return self._function( "greatest", self._function("least", expr, upper_bound), lower_bound ) - if lower_bound is None: - return self._with_elementwise(_clip_upper, upper_bound=upper_bound) - if upper_bound is None: - return self._with_elementwise(_clip_lower, lower_bound=lower_bound) return self._with_elementwise( - _clip_both, lower_bound=lower_bound, upper_bound=upper_bound + _clip, lower_bound=lower_bound, upper_bound=upper_bound ) + def clip_lower(self, lower_bound: Self) -> Self: + def _clip(expr: NativeExprT, lower_bound: NativeExprT) -> NativeExprT: + return self._function("greatest", expr, lower_bound) + + return self._with_elementwise(_clip, lower_bound=lower_bound) + + def clip_upper(self, upper_bound: Self) -> Self: + def _clip(expr: NativeExprT, upper_bound: NativeExprT) -> NativeExprT: + return self._function("least", expr, upper_bound) + + return self._with_elementwise(_clip, upper_bound=upper_bound) + def is_null(self) -> Self: return self._with_elementwise(lambda expr: self._function("isnull", expr)) @@ -895,4 +873,5 @@ def str(self) -> SQLExprStringNamespace[Self]: ... def dt(self) -> SQLExprDateTimeNamesSpace[Self]: ... drop_nulls = not_implemented() # type: ignore[misc] + filter = not_implemented() # type: ignore[misc] unique = not_implemented() # type: ignore[misc] diff --git a/narwhals/_sql/expr_dt.py b/narwhals/_sql/expr_dt.py index 85b65aaf05..8c660bc500 100644 --- a/narwhals/_sql/expr_dt.py +++ b/narwhals/_sql/expr_dt.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import DateTimeNamespace from narwhals._sql.typing import SQLExprT +if TYPE_CHECKING: + from narwhals._compliant.expr import NativeExpr + class SQLExprDateTimeNamesSpace( LazyExprNamespace[SQLExprT], DateTimeNamespace[SQLExprT], Generic[SQLExprT] ): - def _function(self, name: str, *args: Any) -> SQLExprT: + def _function(self, name: str, *args: Any) -> NativeExpr: return self.compliant._function(name, *args) # type: ignore[no-any-return] def year(self) -> SQLExprT: diff --git a/narwhals/_sql/expr_str.py b/narwhals/_sql/expr_str.py index 7e82ed30ec..db43531823 100644 --- a/narwhals/_sql/expr_str.py +++ b/narwhals/_sql/expr_str.py @@ -1,26 +1,32 @@ from __future__ import annotations -from typing import Any, Generic +import operator +from typing import TYPE_CHECKING, Any, Generic from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import StringNamespace from narwhals._sql.typing import SQLExprT +if TYPE_CHECKING: + from narwhals._compliant.expr import NativeExpr + class SQLExprStringNamespace( LazyExprNamespace[SQLExprT], StringNamespace[SQLExprT], Generic[SQLExprT] ): - def _lit(self, value: Any) -> SQLExprT: + def _lit(self, value: Any) -> NativeExpr: return self.compliant._lit(value) # type: ignore[no-any-return] - def _function(self, name: str, *args: Any) -> SQLExprT: + def _function(self, name: str, *args: Any) -> NativeExpr: return self.compliant._function(name, *args) # type: ignore[no-any-return] - def _when(self, condition: Any, value: Any, otherwise: Any | None = None) -> SQLExprT: + def _when( + self, condition: Any, value: Any, otherwise: Any | None = None + ) -> NativeExpr: return self.compliant._when(condition, value, otherwise) # type: ignore[no-any-return] def contains(self, pattern: str, *, literal: bool) -> SQLExprT: - def func(expr: Any) -> Any: + def func(expr: NativeExpr) -> NativeExpr: if literal: return self._function("contains", expr, self._lit(pattern)) return self._function("regexp_matches", expr, self._lit(pattern)) @@ -37,22 +43,12 @@ def len_chars(self) -> SQLExprT: lambda expr: self._function("length", expr) ) - def replace_all( - self, pattern: str, value: str | SQLExprT, *, literal: bool - ) -> SQLExprT: + def replace_all(self, value: SQLExprT, pattern: str, *, literal: bool) -> SQLExprT: fname: str = "replace" if literal else "regexp_replace" options: list[Any] = [] if not literal and self.compliant._implementation.is_duckdb(): options = [self._lit("g")] - - if isinstance(value, str): - return self.compliant._with_elementwise( - lambda expr: self._function( - fname, expr, self._lit(pattern), self._lit(value), *options - ) - ) - return self.compliant._with_elementwise( lambda expr, value: self._function( fname, expr, self._lit(pattern), value, *options @@ -61,11 +57,11 @@ def replace_all( ) def slice(self, offset: int, length: int | None) -> SQLExprT: - def func(expr: SQLExprT) -> SQLExprT: + def func(expr: NativeExpr) -> NativeExpr: col_length = self._function("length", expr) _offset = ( - col_length + self._lit(offset + 1) + operator.add(col_length, self._lit(offset + 1)) if offset < 0 else self._lit(offset + 1) ) @@ -109,7 +105,7 @@ def zfill(self, width: int) -> SQLExprT: # There is no built-in zfill function, so we need to implement it manually # using string manipulation functions. - def func(expr: Any) -> Any: + def func(expr: NativeExpr) -> NativeExpr: less_than_width = self._function("length", expr) < self._lit(width) zero, hyphen, plus = self._lit("0"), self._lit("-"), self._lit("+") @@ -120,10 +116,10 @@ def func(expr: Any) -> Any: "lpad", substring, self._lit(width - 1), zero ) return self._when( - starts_with_minus & less_than_width, + operator.and_(starts_with_minus, less_than_width), self._function("concat", hyphen, padded_substring), self._when( - starts_with_plus & less_than_width, + operator.and_(starts_with_plus, less_than_width), self._function("concat", plus, padded_substring), self._when( less_than_width, diff --git a/narwhals/_sql/namespace.py b/narwhals/_sql/namespace.py index dee8a7e470..94f61d7c65 100644 --- a/narwhals/_sql/namespace.py +++ b/narwhals/_sql/namespace.py @@ -71,3 +71,18 @@ def func(cols: Iterable[NativeExprT]) -> NativeExprT: return self._coalesce(*cols) return self._expr._from_elementwise_horizontal_op(func, *exprs) + + def when_then( + self, predicate: SQLExprT, then: SQLExprT, otherwise: SQLExprT | None = None + ) -> SQLExprT: + def func(cols: list[NativeExprT]) -> NativeExprT: + return self._when(cols[1], cols[0]) + + def func_with_otherwise(cols: list[NativeExprT]) -> NativeExprT: + return self._when(cols[1], cols[0], cols[2]) + + if otherwise is None: + return self._expr._from_elementwise_horizontal_op(func, then, predicate) + return self._expr._from_elementwise_horizontal_op( + func_with_otherwise, then, predicate, otherwise + ) diff --git a/narwhals/_sql/when_then.py b/narwhals/_sql/when_then.py deleted file mode 100644 index 11c5bf5e20..0000000000 --- a/narwhals/_sql/when_then.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -from narwhals._compliant.typing import NativeExprT -from narwhals._compliant.when_then import CompliantThen, CompliantWhen -from narwhals._sql.typing import SQLExprT, SQLLazyFrameT - -if TYPE_CHECKING: - from collections.abc import Sequence - - from typing_extensions import Self - - from narwhals._compliant.typing import WindowFunction - from narwhals._compliant.when_then import IntoExpr - from narwhals._compliant.window import WindowInputs - from narwhals._utils import _LimitedContext - - -class SQLWhen( - CompliantWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - Protocol[SQLLazyFrameT, NativeExprT, SQLExprT], -): - @property - def _then(self) -> type[SQLThen[SQLLazyFrameT, NativeExprT, SQLExprT]]: ... - - def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]: - is_expr = self._condition._is_expr - when = df.__narwhals_namespace__()._when - lit = df.__narwhals_namespace__()._lit - condition = df._evaluate_expr(self._condition) - then_ = self._then_value - then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_) - other_ = self._otherwise_value - if other_ is None: - result = when(condition, then) - else: - otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_) - result = when(condition, then).otherwise(otherwise) - return [result] - - @classmethod - def from_expr(cls, condition: SQLExprT, /, *, context: _LimitedContext) -> Self: - obj = cls.__new__(cls) - obj._condition = condition - obj._then_value = None - obj._otherwise_value = None - obj._implementation = context._implementation - obj._version = context._version - return obj - - def _window_function( - self, df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] - ) -> Sequence[NativeExprT]: - when = df.__narwhals_namespace__()._when - lit = df.__narwhals_namespace__()._lit - is_expr = self._condition._is_expr - condition = self._condition.window_function(df, window_inputs)[0] - then_ = self._then_value - then = ( - then_.window_function(df, window_inputs)[0] if is_expr(then_) else lit(then_) - ) - - other_ = self._otherwise_value - if other_ is None: - result = when(condition, then) - else: - other = ( - other_.window_function(df, window_inputs)[0] - if is_expr(other_) - else lit(other_) - ) - result = when(condition, then).otherwise(other) - return [result] - - -class SQLThen( - CompliantThen[ - SQLLazyFrameT, - NativeExprT, - SQLExprT, - SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - ], - Protocol[SQLLazyFrameT, NativeExprT, SQLExprT], -): - _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None - - @classmethod - def from_when( - cls, - when: SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - then: IntoExpr[NativeExprT, SQLExprT], - /, - ) -> Self: - when._then_value = then - obj = cls.__new__(cls) - obj._call = when - obj._window_function = when._window_function - obj._when_value = when - obj._evaluate_output_names = getattr( - then, "_evaluate_output_names", lambda _df: ["literal"] - ) - obj._alias_output_names = getattr(then, "_alias_output_names", None) - obj._implementation = when._implementation - obj._version = when._version - return obj diff --git a/narwhals/_utils.py b/narwhals/_utils.py index e2ea550bea..c7066a0d5f 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -66,14 +66,7 @@ TypeIs, ) - from narwhals._compliant import ( - CompliantExpr, - CompliantExprT, - CompliantFrameT, - CompliantSeriesOrNativeExprT_co, - CompliantSeriesT, - NativeSeriesT_co, - ) + from narwhals._compliant import CompliantExprT, CompliantSeriesT, NativeSeriesT_co from narwhals._compliant.any_namespace import NamespaceAccessor from narwhals._compliant.typing import ( Accessor, @@ -1583,12 +1576,6 @@ def is_compliant_series_int( return is_compliant_series(obj) and obj.dtype.is_integer() -def is_compliant_expr( - obj: CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | Any, -) -> TypeIs[CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]]: - return hasattr(obj, "__narwhals_expr__") - - def _is_namespace_accessor(obj: _IntoContext) -> TypeIs[NamespaceAccessor[_FullContext]]: # NOTE: Only `compliant` has false positives **internally** # - https://github.com/narwhals-dev/narwhals/blob/cc69bac35eb8c81a1106969c49bfba9fd569b856/narwhals/_compliant/group_by.py#L44-L49 @@ -1785,7 +1772,7 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: return self.__get__("raise") @classmethod - def deprecated(cls, message: LiteralString, /) -> Self: + def deprecated(cls, message: LiteralString, /) -> Self: # pragma: no cover """Alt constructor, wraps with `@deprecated`. Arguments: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 3fc3172241..74f1c68b03 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from functools import partial from itertools import chain from typing import ( TYPE_CHECKING, @@ -17,9 +18,8 @@ from narwhals._exceptions import issue_warning from narwhals._expression_parsing import ( - ExprKind, + _parse_into_expr, check_expressions_preserve_length, - is_into_expr_eager, is_scalar_like, ) from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl @@ -48,7 +48,6 @@ from narwhals.dependencies import is_numpy_array_2d, is_pyarrow_table from narwhals.exceptions import ( ColumnNotFoundError, - InvalidIntoExprError, InvalidOperationError, PerformanceWarning, ) @@ -69,7 +68,8 @@ from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame - from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny + from narwhals._compliant.typing import CompliantExprAny + from narwhals._expression_parsing import ExprMetadata from narwhals._translate import IntoArrowTable from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars from narwhals.group_by import GroupBy, LazyGroupBy @@ -144,24 +144,23 @@ def _with_compliant(self, df: Any) -> Self: def _flatten_and_extract( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr - ) -> tuple[list[CompliantExprAny], list[ExprKind]]: + ) -> list[CompliantExprAny]: # Process `args` and `kwargs`, extracting underlying objects as we go. # NOTE: Strings are interpreted as column names. out_exprs = [] - out_kinds = [] - for expr in flatten(exprs): - compliant_expr = self._extract_compliant(expr) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) - for alias, expr in named_exprs.items(): - compliant_expr = self._extract_compliant(expr).alias(alias) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) - return out_exprs, out_kinds - - @abstractmethod - def _extract_compliant(self, arg: Any) -> Any: - raise NotImplementedError + ns = self.__narwhals_namespace__() + parse = partial( + _parse_into_expr, backend=self._compliant._implementation, allow_literal=False + ) + all_exprs = chain( + (parse(x) for x in flatten(exprs)), + (parse(expr).alias(alias) for alias, expr in named_exprs.items()), + ) + for expr in all_exprs: + ce = expr._to_compliant_expr(ns) + out_exprs.append(ce) + self._validate_metadata(ce._metadata) + return out_exprs def _extract_compliant_frame(self, other: Self | Any, /) -> Any: if isinstance(other, type(self)): @@ -172,6 +171,10 @@ def _extract_compliant_frame(self, other: Self | Any, /) -> Any: def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: return check_columns_exist(subset, available=self.columns) + @abstractmethod + def _validate_metadata(self, metadata: ExprMetadata) -> None: # pragma: no cover + pass + @property def schema(self) -> Schema: return Schema(self._compliant_frame.schema.items()) @@ -200,10 +203,12 @@ def columns(self) -> list[str]: def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr ) -> Self: - compliant_exprs, kinds = self._flatten_and_extract(*exprs, **named_exprs) + compliant_exprs = self._flatten_and_extract(*exprs, **named_exprs) compliant_exprs = [ - compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + compliant_expr.broadcast() + if is_scalar_like(compliant_expr) + else compliant_expr + for compliant_expr in compliant_exprs ] return self._with_compliant(self._compliant_frame.with_columns(*compliant_exprs)) @@ -222,12 +227,14 @@ def select( if error := self._check_columns_exist(flat_exprs): raise error from e raise - compliant_exprs, kinds = self._flatten_and_extract(*flat_exprs, **named_exprs) - if compliant_exprs and all(is_scalar_like(kind) for kind in kinds): + compliant_exprs = self._flatten_and_extract(*flat_exprs, **named_exprs) + if compliant_exprs and all(is_scalar_like(x) for x in compliant_exprs): return self._with_compliant(self._compliant_frame.aggregate(*compliant_exprs)) compliant_exprs = [ - compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + compliant_expr.broadcast() + if is_scalar_like(compliant_expr) + else compliant_expr + for compliant_expr in compliant_exprs ] return self._with_compliant(self._compliant_frame.select(*compliant_exprs)) @@ -249,11 +256,11 @@ def filter( from narwhals.functions import col flat_predicates = flatten(predicates) - check_expressions_preserve_length(*flat_predicates, function_name="filter") plx = self.__narwhals_namespace__() - compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates) - compliant_constraints = ( - (col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items() + compliant_predicates = self._flatten_and_extract(*flat_predicates) + check_expressions_preserve_length(*compliant_predicates, function_name="filter") + compliant_constraints = self._flatten_and_extract( + *[col(name) == v for name, v in constraints.items()] ) predicate = plx.all_horizontal( *chain(compliant_predicates, compliant_constraints), ignore_nulls=False @@ -473,12 +480,6 @@ class DataFrame(BaseFrame[DataFrameT]): def _compliant(self) -> CompliantDataFrame[Any, Any, DataFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - if is_into_expr_eager(arg): - plx: EagerNamespaceAny = self.__narwhals_namespace__() - return plx.parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) - @property def _series(self) -> type[Series[Any]]: return Series @@ -487,6 +488,10 @@ def _series(self) -> type[Series[Any]]: def _lazyframe(self) -> type[LazyFrame[Any]]: return LazyFrame + def _validate_metadata(self, metadata: ExprMetadata) -> None: + # all is valid in eager case. + pass + def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> None: self._level: Literal["full", "lazy", "interchange"] = level self._compliant_frame: CompliantDataFrame[Any, Any, DataFrameT, Self] @@ -1788,8 +1793,10 @@ def group_by( k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr_or_series) ] - expr_flat_keys, _kinds = self._flatten_and_extract(*_keys) - check_expressions_preserve_length(*_keys, function_name="DataFrame.group_by") + expr_flat_keys = self._flatten_and_extract(*_keys) + check_expressions_preserve_length( + *expr_flat_keys, function_name="DataFrame.group_by" + ) return GroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( @@ -2352,45 +2359,35 @@ class LazyFrame(BaseFrame[LazyFrameT]): def _compliant(self) -> CompliantLazyFrame[Any, LazyFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Binary operations between Series and LazyFrame are not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - if isinstance(arg, Expr): - if arg._metadata.n_orderable_ops: - msg = ( - "Order-dependent expressions are not supported for use in LazyFrame.\n\n" - "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" - "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" - "`'date'` which orders your data, then replace:\n\n" - " nw.col('price').cum_sum()\n\n" - " with:\n\n" - " nw.col('price').cum_sum().over(order_by='date')\n" - " ^^^^^^^^^^^^^^^^^^^^^^\n\n" - "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." - ) - raise InvalidOperationError(msg) - if arg._metadata.is_filtration: - msg = ( - "Length-changing expressions are not supported for use in LazyFrame, unless\n" - "followed by an aggregation.\n\n" - "Hints:\n" - "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" - "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" - " use `lf.select(nw.col('a').drop_nulls().sum())\n" - ) - raise InvalidOperationError(msg) - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) - @property def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame + def _validate_metadata(self, metadata: ExprMetadata) -> None: + if metadata.n_orderable_ops > 0: + msg = ( + "Order-dependent expressions are not supported for use in LazyFrame.\n\n" + "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" + "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" + "`'date'` which orders your data, then replace:\n\n" + " nw.col('price').cum_sum()\n\n" + " with:\n\n" + " nw.col('price').cum_sum().over(order_by='date')\n" + " ^^^^^^^^^^^^^^^^^^^^^^\n\n" + "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." + ) + raise InvalidOperationError(msg) + if metadata.is_filtration: + msg = ( + "Length-changing expressions are not supported for use in LazyFrame, unless\n" + "followed by an aggregation.\n\n" + "Hints:\n" + "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" + "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" + " use `lf.select(nw.col('a').drop_nulls().sum())\n" + ) + raise InvalidOperationError(msg) + def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> None: self._level = level self._compliant_frame: CompliantLazyFrame[Any, LazyFrameT, Self] @@ -3021,9 +3018,10 @@ def group_by( _keys = [ k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr) ] - expr_flat_keys, _kinds = self._flatten_and_extract(*_keys) - check_expressions_preserve_length(*_keys, function_name="LazyFrame.group_by") - + expr_flat_keys = self._flatten_and_extract(*_keys) + check_expressions_preserve_length( + *expr_flat_keys, function_name="LazyFrame.group_by" + ) return LazyGroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( diff --git a/narwhals/expr.py b/narwhals/expr.py index df72bcbfea..887f277cc6 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1,14 +1,14 @@ from __future__ import annotations import math -import operator as op from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Callable from narwhals._expression_parsing import ( - ExprMetadata, - apply_n_ary_operation, - combine_metadata, + ExprKind, + ExprNode, + evaluate_node, + evaluate_root_node, ) from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten from narwhals.dtypes import _validate_dtype @@ -24,10 +24,11 @@ if TYPE_CHECKING: from typing import NoReturn, TypeVar - from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias + from typing_extensions import Concatenate, ParamSpec, Self from narwhals._compliant import CompliantExpr, CompliantNamespace from narwhals.dtypes import DType + from narwhals.series import Series from narwhals.typing import ( ClosedInterval, FillNullStrategy, @@ -39,74 +40,60 @@ RankMethod, RollingInterpolationMethod, TemporalLiteral, - _1DArray, ) PS = ParamSpec("PS") R = TypeVar("R") - _ToCompliant: TypeAlias = Callable[ - [CompliantNamespace[Any, Any]], CompliantExpr[Any, Any] - ] class Expr: - def __init__(self, to_compliant_expr: _ToCompliant, metadata: ExprMetadata) -> None: - # callable from CompliantNamespace to CompliantExpr - def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: - result = to_compliant_expr(plx) - result._metadata = self._metadata - return result - - self._to_compliant_expr: _ToCompliant = func - self._metadata = metadata - - def _with_elementwise(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_elementwise_op()) - - def _with_aggregation(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_aggregation()) - - def _with_orderable_aggregation( - self, to_compliant_expr: Callable[[Any], Any] - ) -> Self: - return self.__class__( - to_compliant_expr, self._metadata.with_orderable_aggregation() - ) - - def _with_orderable_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_orderable_window()) - - def _with_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_window()) - - def _with_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_filtration()) - - def _with_orderable_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__( - to_compliant_expr, self._metadata.with_orderable_filtration() - ) - - def _with_nary( - self, - n_ary_function: Callable[..., Any], - *args: IntoExpr | NonNestedLiteral | _1DArray, - ) -> Self: - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, n_ary_function, self, *args, str_as_lit=False - ), - combine_metadata( - self, - *args, - str_as_lit=False, - allow_multi_output=False, - to_single_output=False, - ), - ) + def __init__(self, *nodes: ExprNode) -> None: + self._nodes = nodes + + def _to_compliant_expr( + self, ns: CompliantNamespace[Any, Any] + ) -> CompliantExpr[Any, Any]: + nodes = self._nodes + ce = evaluate_root_node(nodes[0], ns) + for node in nodes[1:]: + ce = evaluate_node(ce, node, ns) + return ce + + def _append_node(self, node: ExprNode) -> Self: + return self.__class__(*self._nodes, node) + + def _with_over_node(self, node: ExprNode) -> Self: + # insert `over` before any elementwise operations. + # check "how it works" page in docs for why we do this. + new_nodes = list(self._nodes) + kwargs_no_order_by = { + key: value if key != "order_by" else [] + for (key, value) in node.kwargs.items() + } + node_without_order_by = node._with_kwargs(**kwargs_no_order_by) + n = len(new_nodes) + i = n + while i > 0 and (_node := new_nodes[i - 1]).kind is ExprKind.ELEMENTWISE: + i -= 1 + _node._push_down_over_node_in_place(node, node_without_order_by) + if i == n: + # node could not be pushed down, just append as-is + new_nodes.append(node) + return self.__class__(*new_nodes) + if node.kwargs["order_by"] and any(node.is_orderable() for node in new_nodes[:i]): + new_nodes.insert(i, node) + elif node.kwargs["partition_by"] and any( + not node.is_elementwise() for node in new_nodes[:i] + ): + new_nodes.insert(i, node_without_order_by) + elif all(node.is_elementwise() for node in new_nodes): + msg = "Cannot apply `over` to elementwise expression." + raise InvalidOperationError(msg) + return self.__class__(*new_nodes) def __repr__(self) -> str: - return f"Narwhals Expr\nmetadata: {self._metadata}\n" + """Pretty-print the expression by combining all nodes in the metadata.""" + return ".".join(repr(node) for node in self._nodes) def __bool__(self) -> NoReturn: msg = ( @@ -124,9 +111,7 @@ def __bool__(self) -> NoReturn: def _taxicab_norm(self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).abs().sum() - ) + return self.abs().sum() # --- convert --- def alias(self, name: str) -> Self: @@ -149,10 +134,7 @@ def alias(self, name: str) -> Self: | 1 15 | └──────────────────┘ """ - # Don't use `_with_elementwise` so that `_metadata.last_node` is preserved. - return self.__class__( - lambda plx: self._to_compliant_expr(plx).alias(name), self._metadata - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "alias", name=name)) def pipe( self, @@ -207,102 +189,88 @@ def cast(self, dtype: IntoDType) -> Self: └──────────────────┘ """ _validate_dtype(dtype) - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).cast(dtype) - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "cast", dtype=dtype)) # --- binary --- - def _with_binary( - self, - function: Callable[[Any, Any], Any], - other: Self | Any, - *, - str_as_lit: bool = True, - ) -> Self: - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, function, self, other, str_as_lit=str_as_lit - ), - ExprMetadata.from_binary_op(self, other), - ) + def _with_binary(self, attr: str, other: Self | Any) -> Self: + node = ExprNode(ExprKind.ELEMENTWISE, attr, other, str_as_lit=True) + return self._append_node(node) def __eq__(self, other: Self | Any) -> Self: # type: ignore[override] - return self._with_binary(op.eq, other) + return self._with_binary("__eq__", other) def __ne__(self, other: Self | Any) -> Self: # type: ignore[override] - return self._with_binary(op.ne, other) + return self._with_binary("__ne__", other) def __and__(self, other: Any) -> Self: - return self._with_binary(op.and_, other) + return self._with_binary("__and__", other) def __rand__(self, other: Any) -> Self: return (self & other).alias("literal") # type: ignore[no-any-return] def __or__(self, other: Any) -> Self: - return self._with_binary(op.or_, other) + return self._with_binary("__or__", other) def __ror__(self, other: Any) -> Self: return (self | other).alias("literal") # type: ignore[no-any-return] def __add__(self, other: Any) -> Self: - return self._with_binary(op.add, other) + return self._with_binary("__add__", other) def __radd__(self, other: Any) -> Self: return (self + other).alias("literal") # type: ignore[no-any-return] def __sub__(self, other: Any) -> Self: - return self._with_binary(op.sub, other) + return self._with_binary("__sub__", other) def __rsub__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rsub__(y), other) + return self._with_binary("__rsub__", other) def __truediv__(self, other: Any) -> Self: - return self._with_binary(op.truediv, other) + return self._with_binary("__truediv__", other) def __rtruediv__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rtruediv__(y), other) + return self._with_binary("__rtruediv__", other) def __mul__(self, other: Any) -> Self: - return self._with_binary(op.mul, other) + return self._with_binary("__mul__", other) def __rmul__(self, other: Any) -> Self: return (self * other).alias("literal") # type: ignore[no-any-return] def __le__(self, other: Any) -> Self: - return self._with_binary(op.le, other) + return self._with_binary("__le__", other) def __lt__(self, other: Any) -> Self: - return self._with_binary(op.lt, other) + return self._with_binary("__lt__", other) def __gt__(self, other: Any) -> Self: - return self._with_binary(op.gt, other) + return self._with_binary("__gt__", other) def __ge__(self, other: Any) -> Self: - return self._with_binary(op.ge, other) + return self._with_binary("__ge__", other) def __pow__(self, other: Any) -> Self: - return self._with_binary(op.pow, other) + return self._with_binary("__pow__", other) def __rpow__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rpow__(y), other) + return self._with_binary("__rpow__", other) def __floordiv__(self, other: Any) -> Self: - return self._with_binary(op.floordiv, other) + return self._with_binary("__floordiv__", other) def __rfloordiv__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rfloordiv__(y), other) + return self._with_binary("__rfloordiv__", other) def __mod__(self, other: Any) -> Self: - return self._with_binary(op.mod, other) + return self._with_binary("__mod__", other) def __rmod__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rmod__(y), other) + return self._with_binary("__rmod__", other) # --- unary --- def __invert__(self) -> Self: - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).__invert__() - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "__invert__")) def any(self) -> Self: """Return whether any of the values in the column are `True`. @@ -322,7 +290,7 @@ def any(self) -> Self: | 0 True True | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).any()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "any")) def all(self) -> Self: """Return whether all values in the column are `True`. @@ -342,7 +310,7 @@ def all(self) -> Self: | 0 False True | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).all()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "all")) def ewm_mean( self, @@ -427,8 +395,10 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).ewm_mean( + return self._append_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "ewm_mean", com=com, span=span, half_life=half_life, @@ -455,7 +425,7 @@ def mean(self) -> Self: | 0 0.0 4.0 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).mean()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "mean")) def median(self) -> Self: """Get median value. @@ -476,7 +446,7 @@ def median(self) -> Self: | 0 3.0 4.0 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).median()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "median")) def std(self, *, ddof: int = 1) -> Self: """Get standard deviation. @@ -498,9 +468,7 @@ def std(self, *, ddof: int = 1) -> Self: |0 17.79513 1.265789| └─────────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).std(ddof=ddof) - ) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "std", ddof=ddof)) def var(self, *, ddof: int = 1) -> Self: """Get variance. @@ -522,9 +490,7 @@ def var(self, *, ddof: int = 1) -> Self: |0 316.666667 1.602222| └───────────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).var(ddof=ddof) - ) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "var", ddof=ddof)) def map_batches( self, @@ -568,18 +534,20 @@ def map_batches( |2 3 6 4.0 7.0| └───────────────────────────┘ """ - - def compliant_expr(plx: Any) -> Any: - return self._to_compliant_expr(plx).map_batches( + kind = ( + ExprKind.ORDERABLE_AGGREGATION + if returns_scalar + else ExprKind.ORDERABLE_FILTRATION + ) + return self._append_node( + ExprNode( + kind, + "map_batches", function=function, return_dtype=return_dtype, returns_scalar=returns_scalar, ) - - if returns_scalar: - return self._with_orderable_aggregation(compliant_expr) - # safest assumptions - return self._with_orderable_filtration(compliant_expr) + ) def skew(self) -> Self: """Calculate the sample skewness of a column. @@ -597,7 +565,7 @@ def skew(self) -> Self: | 0 0.0 1.472427 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).skew()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "skew")) def kurtosis(self) -> Self: """Compute the kurtosis (Fisher's definition) without bias correction. @@ -618,9 +586,9 @@ def kurtosis(self) -> Self: | 0 -1.3 0.210657 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).kurtosis()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "kurtosis")) - def sum(self) -> Expr: + def sum(self) -> Self: """Return the sum value. If there are no non-null elements, the result is zero. @@ -642,7 +610,7 @@ def sum(self) -> Expr: |└────────┴────────┘| └───────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).sum()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "sum")) def min(self) -> Self: """Returns the minimum value(s) from a column(s). @@ -660,7 +628,7 @@ def min(self) -> Self: | 0 1 3 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).min()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "min")) def max(self) -> Self: """Returns the maximum value(s) from a column(s). @@ -678,7 +646,7 @@ def max(self) -> Self: | 0 20 100 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).max()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "max")) def count(self) -> Self: """Returns the number of non-null elements in the column. @@ -696,7 +664,7 @@ def count(self) -> Self: | 0 3 2 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).count()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "count")) def n_unique(self) -> Self: """Returns count of unique values. @@ -714,7 +682,7 @@ def n_unique(self) -> Self: | 0 5 3 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).n_unique()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "n_unique")) def unique(self) -> Self: """Return unique values of this expression. @@ -732,7 +700,7 @@ def unique(self) -> Self: | 0 9 12 | └──────────────────┘ """ - return self._with_filtration(lambda plx: self._to_compliant_expr(plx).unique()) + return self._append_node(ExprNode(ExprKind.FILTRATION, "unique")) def abs(self) -> Self: """Return absolute value of each element. @@ -751,7 +719,7 @@ def abs(self) -> Self: |1 -2 4 2 4| └─────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).abs()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "abs")) def cum_sum(self, *, reverse: bool = False) -> Self: """Return cumulative sum. @@ -780,8 +748,8 @@ def cum_sum(self, *, reverse: bool = False) -> Self: |4 5 6 15| └──────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse) + return self._append_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_sum", reverse=reverse) ) def diff(self) -> Self: @@ -823,9 +791,7 @@ def diff(self) -> Self: | └─────┴────────┘ | └──────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).diff() - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "diff")) def shift(self, n: int) -> Self: """Shift values by `n` positions. @@ -870,10 +836,7 @@ def shift(self, n: int) -> Self: └──────────────────┘ """ ensure_type(n, int, param_name="n") - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).shift(n) - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "shift", n=n)) def replace_strict( self, @@ -925,9 +888,13 @@ def replace_strict( new = list(old.values()) old = list(old.keys()) - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).replace_strict( - old, new, return_dtype=return_dtype + return self._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "replace_strict", + old=old, + new=new, + return_dtype=return_dtype, ) ) @@ -962,11 +929,10 @@ def is_between( | 4 5 False | └──────────────────┘ """ - return self._with_nary( - lambda expr, lb, ub: expr.is_between(lb, ub, closed=closed), - lower_bound, - upper_bound, + node = ExprNode( + ExprKind.ELEMENTWISE, "is_between", lower_bound, upper_bound, closed=closed ) + return self._append_node(node) def is_in(self, other: Any) -> Self: """Check if elements of this expression are present in the other iterable. @@ -991,9 +957,11 @@ def is_in(self, other: Any) -> Self: └──────────────────┘ """ if isinstance(other, Iterable) and not isinstance(other, (str, bytes)): - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).is_in( - to_native(other, pass_through=True) + return self._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "is_in", + other=to_native(other, pass_through=True), ) ) msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -1025,24 +993,7 @@ def filter(self, *predicates: Any) -> Self: | 5 7 12 | └──────────────────┘ """ - flat_predicates = flatten(predicates) - metadata = combine_metadata( - self, - *flat_predicates, - str_as_lit=False, - allow_multi_output=True, - to_single_output=False, - ).with_filtration() - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, - lambda *exprs: exprs[0].filter(*exprs[1:]), - self, - *flat_predicates, - str_as_lit=False, - ), - metadata, - ) + return self._append_node(ExprNode(ExprKind.FILTRATION, "filter", *predicates)) def is_null(self) -> Self: """Returns a boolean Series indicating which values are null. @@ -1073,7 +1024,7 @@ def is_null(self) -> Self: |└───────┴────────┴───────────┴───────────┘| └──────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).is_null()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_null")) def is_nan(self) -> Self: """Indicate which values are NaN. @@ -1104,7 +1055,7 @@ def is_nan(self) -> Self: |└───────┴────────┴──────────┴──────────┘| └────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).is_nan()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_nan")) def fill_null( self, @@ -1191,20 +1142,24 @@ def fill_null( msg = f"strategy not supported: {strategy}" raise ValueError(msg) - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, - lambda *exprs: exprs[0].fill_null( - exprs[1], strategy=strategy, limit=limit - ), - self, + if strategy is not None: + node = ExprNode( + ExprKind.ORDERABLE_WINDOW, + "fill_null", + value=value, + strategy=strategy, + limit=limit, + ) + else: + node = ExprNode( + ExprKind.ELEMENTWISE, + "fill_null", value, + strategy=strategy, + limit=limit, str_as_lit=True, - ), - self._metadata.with_orderable_window() - if strategy is not None - else self._metadata, - ) + ) + return self._append_node(node) def fill_nan(self, value: float | None) -> Self: """Fill floating point NaN values with given value. @@ -1237,9 +1192,7 @@ def fill_nan(self, value: float | None) -> Self: |└────────┴────────┴───────────────┴───────────────┘| └───────────────────────────────────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).fill_nan(value) - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "fill_nan", value=value)) # --- partial reduction --- def drop_nulls(self) -> Self: @@ -1272,9 +1225,7 @@ def drop_nulls(self) -> Self: | └─────┘ | └──────────────────┘ """ - return self._with_filtration( - lambda plx: self._to_compliant_expr(plx).drop_nulls() - ) + return self._append_node(ExprNode(ExprKind.FILTRATION, "drop_nulls")) def over( self, @@ -1324,22 +1275,10 @@ def over( if not flat_partition_by and not flat_order_by: # pragma: no cover msg = "At least one of `partition_by` or `order_by` must be specified." raise ValueError(msg) - - current_meta = self._metadata - if flat_order_by: - next_meta = current_meta.with_ordered_over() - elif not flat_partition_by: # pragma: no cover - msg = "At least one of `partition_by` or `order_by` must be specified." - raise InvalidOperationError(msg) - else: - next_meta = current_meta.with_partitioned_over() - - return self.__class__( - lambda plx: self._to_compliant_expr(plx).over( - flat_partition_by, flat_order_by - ), - next_meta, + node = ExprNode( + ExprKind.OVER, "over", partition_by=flat_partition_by, order_by=flat_order_by ) + return self._with_over_node(node) def is_duplicated(self) -> Self: r"""Return a boolean mask indicating duplicated values. @@ -1360,7 +1299,7 @@ def is_duplicated(self) -> Self: |3 1 c True False| └─────────────────────────────────────────┘ """ - return self._with_window(lambda plx: self._to_compliant_expr(plx).is_duplicated()) + return self._append_node(ExprNode(ExprKind.WINDOW, "is_duplicated")) def is_unique(self) -> Self: r"""Return a boolean mask indicating unique values. @@ -1381,7 +1320,7 @@ def is_unique(self) -> Self: |3 1 c False True| └─────────────────────────────────┘ """ - return self._with_window(lambda plx: self._to_compliant_expr(plx).is_unique()) + return self._append_node(ExprNode(ExprKind.WINDOW, "is_unique")) def null_count(self) -> Self: r"""Count null values. @@ -1405,9 +1344,7 @@ def null_count(self) -> Self: | 0 1 2 | └──────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).null_count() - ) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "null_count")) def is_first_distinct(self) -> Self: r"""Return a boolean mask indicating the first occurrence of each distinct value. @@ -1434,9 +1371,7 @@ def is_first_distinct(self) -> Self: |3 1 c False True| └─────────────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).is_first_distinct() - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_first_distinct")) def is_last_distinct(self) -> Self: r"""Return a boolean mask indicating the last occurrence of each distinct value. @@ -1463,9 +1398,7 @@ def is_last_distinct(self) -> Self: |3 1 c True True| └───────────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).is_last_distinct() - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_last_distinct")) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod @@ -1498,8 +1431,13 @@ def quantile( | 0 24.5 74.5 | └──────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation) + return self._append_node( + ExprNode( + ExprKind.AGGREGATION, + "quantile", + quantile=quantile, + interpolation=interpolation, + ) ) def round(self, decimals: int = 0) -> Self: @@ -1532,8 +1470,8 @@ def round(self, decimals: int = 0) -> Self: |2 3.901234 3.9| └──────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).round(decimals) + return self._append_node( + ExprNode(ExprKind.ELEMENTWISE, "round", decimals=decimals) ) def floor(self) -> Self: @@ -1557,7 +1495,7 @@ def floor(self) -> Self: |floor: [[1,4,-2]] | └────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).floor()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "floor")) def ceil(self) -> Self: r"""Compute the numerical ceiling. @@ -1580,7 +1518,7 @@ def ceil(self) -> Self: |ceil: [[2,5,-1]] | └────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).ceil()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "ceil")) def len(self) -> Self: r"""Return the number of elements in the column. @@ -1603,7 +1541,7 @@ def len(self) -> Self: | 0 2 1 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).len()) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "len")) def clip( self, @@ -1631,13 +1569,16 @@ def clip( | 2 3 3 | └──────────────────┘ """ - return self._with_nary( - lambda *exprs: exprs[0].clip( - exprs[1] if lower_bound is not None else None, - exprs[2] if upper_bound is not None else None, - ), - lower_bound, - upper_bound, + if upper_bound is None: + return self._append_node( + ExprNode(ExprKind.ELEMENTWISE, "clip_lower", lower_bound) + ) + if lower_bound is None: + return self._append_node( + ExprNode(ExprKind.ELEMENTWISE, "clip_upper", upper_bound) + ) + return self._append_node( + ExprNode(ExprKind.ELEMENTWISE, "clip", lower_bound, upper_bound) ) def first(self) -> Self: @@ -1670,9 +1611,7 @@ def first(self) -> Self: | 1 2 None | └──────────────────┘ """ - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).first() - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "first")) def last(self) -> Self: """Get the last value. @@ -1711,9 +1650,7 @@ def last(self) -> Self: |b: [[null,"baz"]] | └──────────────────┘ """ - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).last() - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "last")) def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: r"""Compute the most occurring value(s). @@ -1741,13 +1678,8 @@ def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: if keep not in _supported_keep_values: # pragma: no cover msg = f"`keep` must be one of {_supported_keep_values}, found '{keep}'" raise ValueError(msg) - - def compliant_expr(plx: Any) -> Any: - return self._to_compliant_expr(plx).mode(keep=keep) - - if keep == "any": - return self._with_aggregation(compliant_expr) - return self._with_filtration(compliant_expr) + kind = ExprKind.AGGREGATION if keep == "any" else ExprKind.FILTRATION + return self._append_node(ExprNode(kind, "mode", keep=keep)) def is_finite(self) -> Self: """Returns boolean values indicating which original values are finite. @@ -1781,9 +1713,7 @@ def is_finite(self) -> Self: |└──────┴─────────────┘| └──────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).is_finite() - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_finite")) def cum_count(self, *, reverse: bool = False) -> Self: r"""Return the cumulative count of the non-null values in the column. @@ -1814,8 +1744,8 @@ def cum_count(self, *, reverse: bool = False) -> Self: |3 d 3 1| └─────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse) + return self._append_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_count", reverse=reverse) ) def cum_min(self, *, reverse: bool = False) -> Self: @@ -1847,8 +1777,8 @@ def cum_min(self, *, reverse: bool = False) -> Self: |3 2.0 1.0 2.0| └────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse) + return self._append_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_min", reverse=reverse) ) def cum_max(self, *, reverse: bool = False) -> Self: @@ -1880,8 +1810,8 @@ def cum_max(self, *, reverse: bool = False) -> Self: |3 2.0 3.0 2.0| └────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse) + return self._append_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_max", reverse=reverse) ) def cum_prod(self, *, reverse: bool = False) -> Self: @@ -1913,8 +1843,8 @@ def cum_prod(self, *, reverse: bool = False) -> Self: |3 2.0 6.0 2.0| └──────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse) + return self._append_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_prod", reverse=reverse) ) def rolling_sum( @@ -1959,13 +1889,16 @@ def rolling_sum( |3 4.0 6.0| └─────────────────────┘ """ - window_size, min_samples_int = _validate_rolling_arguments( + window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_sum( - window_size=window_size, min_samples=min_samples_int, center=center + return self._append_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_sum", + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2015,9 +1948,13 @@ def rolling_mean( window_size=window_size, min_samples=min_samples ) - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_mean( - window_size=window_size, min_samples=min_samples, center=center + return self._append_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_mean", + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2072,10 +2009,14 @@ def rolling_var( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_var( - window_size=window_size, min_samples=min_samples, center=center, ddof=ddof + return self._append_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_var", + ddof=ddof, + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2130,10 +2071,14 @@ def rolling_std( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_std( - window_size=window_size, min_samples=min_samples, center=center, ddof=ddof + return self._append_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_std", + ddof=ddof, + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2191,10 +2136,8 @@ def rank(self, method: RankMethod = "average", *, descending: bool = False) -> S ) raise ValueError(msg) - return self._with_window( - lambda plx: self._to_compliant_expr(plx).rank( - method=method, descending=descending - ) + return self._append_node( + ExprNode(ExprKind.WINDOW, "rank", method=method, descending=descending) ) def log(self, base: float = math.e) -> Self: @@ -2225,9 +2168,7 @@ def log(self, base: float = math.e) -> Self: |log_2: [[0,1,2]] | └────────────────────────────────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).log(base=base) - ) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "log", base=base)) def exp(self) -> Self: r"""Compute the exponent. @@ -2250,7 +2191,7 @@ def exp(self) -> Self: |exp: [[0.36787944117144233,1,2.718281828459045]]| └────────────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).exp()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "exp")) def sqrt(self) -> Self: r"""Compute the square root. @@ -2273,11 +2214,11 @@ def sqrt(self) -> Self: |sqrt: [[1,2,3]] | └──────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt()) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "sqrt")) - def is_close( + def is_close( # noqa: PLR0914 self, - other: Self | NumericLiteral, + other: Expr | Series[Any] | NumericLiteral, *, abs_tol: float = 0.0, rel_tol: float = 1e-09, @@ -2345,11 +2286,58 @@ def is_close( msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" raise ComputeError(msg) - kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal} - return self._with_nary( - lambda *exprs: exprs[0].is_close(exprs[1], **kwargs), other + from decimal import Decimal + + other_abs: Expr | Series[Any] | NumericLiteral + other_is_nan: Expr | Series[Any] | bool + other_is_inf: Expr | Series[Any] | bool + other_is_not_inf: Expr | Series[Any] | bool + + if isinstance(other, (float, int, Decimal)): + from math import isinf, isnan + + # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447 + other_abs = other.__abs__() + other_is_nan = isnan(other) + other_is_inf = isinf(other) + + # Define the other_is_not_inf variable to prevent triggering the following warning: + # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be + # > removed in Python 3.16. + other_is_not_inf = not other_is_inf + + else: + other_abs, other_is_nan = other.abs(), other.is_nan() + other_is_not_inf = other.is_finite() | other_is_nan + other_is_inf = ~other_is_not_inf + + rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol + tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None) + + self_is_nan = self.is_nan() + self_is_not_inf = self.is_finite() | self_is_nan + + # Values are close if abs_diff <= tolerance, and both finite + is_close = ( + ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf ) + # Handle infinity cases: infinities are close/equal if they have the same sign + self_sign, other_sign = self > 0, other > 0 + is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign) + + # Handle nan cases: + # * If any value is NaN, then False (via `& ~either_nan`) + # * However, if `nans_equals = True` and if _both_ values are NaN, then True + either_nan = self_is_nan | other_is_nan + result = (is_close | is_same_inf) & ~either_nan + + if nans_equal: + both_nan = self_is_nan & other_is_nan + result = result | both_nan + + return result + @property def str(self) -> ExprStringNamespace[Self]: return ExprStringNamespace(self) diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 5e7229bf30..ff29bdd64d 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -34,6 +36,6 @@ def get_categories(self) -> ExprT: │ mango │ └────────┘ """ - return self._expr._with_filtration( - lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories() + return self._expr._append_node( + ExprNode(ExprKind.FILTRATION, "cat.get_categories") ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 9ae6eac38f..3b8c31055e 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr from narwhals.typing import TimeUnit @@ -38,9 +40,7 @@ def date(self) -> ExprT: │ 2027-12-13 │ └────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.date() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.date")) def year(self) -> ExprT: """Extract year from underlying DateTime representation. @@ -64,9 +64,7 @@ def year(self) -> ExprT: |1 2065-01-01 2065| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.year() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.year")) def month(self) -> ExprT: """Extract month from underlying DateTime representation. @@ -87,9 +85,7 @@ def month(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] month: [[6,1]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.month() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.month")) def day(self) -> ExprT: """Extract day from underlying DateTime representation. @@ -110,9 +106,7 @@ def day(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] day: [[1,1]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.day() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.day")) def hour(self) -> ExprT: """Extract hour from underlying DateTime representation. @@ -142,9 +136,7 @@ def hour(self) -> ExprT: |└─────────────────────┴──────┘| └──────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.hour() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.hour")) def minute(self) -> ExprT: """Extract minutes from underlying DateTime representation. @@ -164,9 +156,7 @@ def minute(self) -> ExprT: 0 1978-01-01 01:01:00 1 1 2065-01-01 10:20:00 20 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.minute() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.minute")) def second(self) -> ExprT: """Extract seconds from underlying DateTime representation. @@ -192,9 +182,7 @@ def second(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.000000]] second: [[1,30]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.second() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.second")) def millisecond(self) -> ExprT: """Extract milliseconds from underlying DateTime representation. @@ -222,9 +210,7 @@ def millisecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] millisecond: [[0,67]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.millisecond")) def microsecond(self) -> ExprT: """Extract microseconds from underlying DateTime representation. @@ -252,9 +238,7 @@ def microsecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] microsecond: [[0,67000]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.microsecond")) def nanosecond(self) -> ExprT: """Extract Nanoseconds from underlying DateTime representation. @@ -282,9 +266,7 @@ def nanosecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] nanosecond: [[0,67000000]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.nanosecond")) def ordinal_day(self) -> ExprT: """Get ordinal day. @@ -306,9 +288,7 @@ def ordinal_day(self) -> ExprT: |1 2020-08-03 216| └───────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.ordinal_day")) def weekday(self) -> ExprT: """Extract the week day from the underlying Date representation. @@ -332,9 +312,7 @@ def weekday(self) -> ExprT: |1 2020-08-03 1| └────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.weekday() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.weekday")) def total_minutes(self) -> ExprT: """Get total minutes. @@ -365,9 +343,7 @@ def total_minutes(self) -> ExprT: │ 20m 40s ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_minutes")) def total_seconds(self) -> ExprT: """Get total seconds. @@ -398,9 +374,7 @@ def total_seconds(self) -> ExprT: │ 20s 40ms ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_seconds")) def total_milliseconds(self) -> ExprT: """Get total milliseconds. @@ -436,8 +410,8 @@ def total_milliseconds(self) -> ExprT: │ 20040µs ┆ 20 │ └──────────────┴──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds() + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_milliseconds") ) def total_microseconds(self) -> ExprT: @@ -471,8 +445,8 @@ def total_microseconds(self) -> ExprT: a: [[10,1200]] a_total_microseconds: [[10,1200]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds() + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_microseconds") ) def total_nanoseconds(self) -> ExprT: @@ -505,8 +479,8 @@ def total_nanoseconds(self) -> ExprT: 0 2024-01-01 00:00:00.000000001 NaN 1 2024-01-01 00:00:00.000000002 1.0 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds() + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_nanoseconds") ) def to_string(self, format: str) -> ExprT: @@ -569,8 +543,8 @@ def to_string(self, format: str) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.to_string", format=format) ) def replace_time_zone(self, time_zone: str | None) -> ExprT: @@ -597,8 +571,8 @@ def replace_time_zone(self, time_zone: str | None) -> ExprT: 0 2024-01-01 00:00:00+05:45 1 2024-01-02 00:00:00+05:45 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.replace_time_zone(time_zone) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.replace_time_zone", time_zone=time_zone) ) def convert_time_zone(self, time_zone: str) -> ExprT: @@ -631,8 +605,8 @@ def convert_time_zone(self, time_zone: str) -> ExprT: if time_zone is None: msg = "Target `time_zone` cannot be `None` in `convert_time_zone`. Please use `replace_time_zone(None)` if you want to remove the time zone." raise TypeError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.convert_time_zone(time_zone) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.convert_time_zone", time_zone=time_zone) ) def timestamp(self, time_unit: TimeUnit = "us") -> ExprT: @@ -671,8 +645,8 @@ def timestamp(self, time_unit: TimeUnit = "us") -> ExprT: f"\n\nExpected one of {{'ns', 'us', 'ms'}}, got {time_unit!r}." ) raise ValueError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.timestamp", time_unit=time_unit) ) def truncate(self, every: str) -> ExprT: @@ -715,8 +689,8 @@ def truncate(self, every: str) -> ExprT: |└─────────────────────┴─────────────────────┘| └─────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.truncate(every) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.truncate", every=every) ) def offset_by(self, by: str) -> ExprT: @@ -759,6 +733,6 @@ def offset_by(self, by: str) -> ExprT: |└─────────────────────┴───────────────────────┘| └───────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.offset_by(by) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.offset_by", by=by) ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index fc01bca035..8f9c94c6ab 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr from narwhals.typing import NonNestedLiteral @@ -40,9 +42,7 @@ def len(self) -> ExprT: |└──────────────┴───────┘| └────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.len() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.len")) def unique(self) -> ExprT: """Get the unique/distinct values in the list. @@ -71,9 +71,7 @@ def unique(self) -> ExprT: |└──────────────┴───────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.unique() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.unique")) def contains(self, item: NonNestedLiteral) -> ExprT: """Check if sublists contain the given item. @@ -102,8 +100,8 @@ def contains(self, item: NonNestedLiteral) -> ExprT: |└───────────┴──────────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.contains(item) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "list.contains", item=item) ) def get(self, index: int) -> ExprT: @@ -142,6 +140,6 @@ def get(self, index: int) -> ExprT: msg = f"Index {index} is out of bounds: should be greater than or equal to 0." raise ValueError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.get(index) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "list.get", index=index) ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index facda33042..64525aae56 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Callable, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -26,9 +28,7 @@ def keep(self) -> ExprT: >>> df.select(nw.col("foo").alias("alias_for_foo").name.keep()).columns ['foo'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.keep() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "name.keep")) def map(self, function: Callable[[str], str]) -> ExprT: r"""Rename the output of an expression by mapping a function over the root name. @@ -48,8 +48,8 @@ def map(self, function: Callable[[str], str]) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.map(renaming_func)).columns ['oof', 'RAB'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.map(function) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.map", function=function) ) def prefix(self, prefix: str) -> ExprT: @@ -69,8 +69,8 @@ def prefix(self, prefix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.prefix("with_prefix")).columns ['with_prefixfoo', 'with_prefixBAR'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.prefix", prefix=prefix) ) def suffix(self, suffix: str) -> ExprT: @@ -90,8 +90,8 @@ def suffix(self, suffix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.suffix("_with_suffix")).columns ['foo_with_suffix', 'BAR_with_suffix'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.suffix", suffix=suffix) ) def to_lowercase(self) -> ExprT: @@ -108,8 +108,8 @@ def to_lowercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_lowercase()).columns ['foo', 'bar'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase() + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.to_lowercase") ) def to_uppercase(self) -> ExprT: @@ -126,6 +126,6 @@ def to_uppercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_uppercase()).columns ['FOO', 'BAR'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase() + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.to_uppercase") ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index b64d4580dc..fe914e8375 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Generic, TypeVar -from narwhals._expression_parsing import apply_n_ary_operation +from narwhals._expression_parsing import ExprKind, ExprNode if TYPE_CHECKING: from narwhals.expr import Expr + from narwhals.typing import IntoExpr ExprT = TypeVar("ExprT", bound="Expr") @@ -38,12 +39,10 @@ def len_chars(self) -> ExprT: |└───────┴───────────┘| └─────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.len_chars() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.len_chars")) def replace( - self, pattern: str, value: str | ExprT, *, literal: bool = False, n: int = 1 + self, pattern: str, value: str | IntoExpr, *, literal: bool = False, n: int = 1 ) -> ExprT: r"""Replace first matching regex/literal substring with a new string value. @@ -67,22 +66,20 @@ def replace( |1 abc abc123 abc123| └──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: ( - apply_n_ary_operation( - plx, - lambda self, value: self.str.replace( - pattern, value, literal=literal, n=n - ), - self._expr, - value, - str_as_lit=True, - ) + return self._expr._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "str.replace", + value, + pattern=pattern, + literal=literal, + n=n, + str_as_lit=True, ) ) def replace_all( - self, pattern: str, value: str | ExprT, *, literal: bool = False + self, pattern: str, value: IntoExpr, *, literal: bool = False ) -> ExprT: r"""Replace all matching regex/literal substring with a new string value. @@ -105,17 +102,14 @@ def replace_all( |1 abc abc123 123| └──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: ( - apply_n_ary_operation( - plx, - lambda self, value: self.str.replace_all( - pattern, value, literal=literal - ), - self._expr, - value, - str_as_lit=True, - ) + return self._expr._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "str.replace_all", + value, + pattern=pattern, + literal=literal, + str_as_lit=True, ) ) @@ -138,8 +132,8 @@ def strip_chars(self, characters: str | None = None) -> ExprT: ... ) {'fruits': ['apple', '\nmango'], 'stripped': ['apple', 'mango']} """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.strip_chars", characters=characters) ) def starts_with(self, prefix: str) -> ExprT: @@ -163,8 +157,8 @@ def starts_with(self, prefix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.starts_with(prefix) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.starts_with", prefix=prefix) ) def ends_with(self, suffix: str) -> ExprT: @@ -188,8 +182,8 @@ def ends_with(self, suffix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.ends_with", suffix=suffix) ) def contains(self, pattern: str, *, literal: bool = False) -> ExprT: @@ -218,9 +212,9 @@ def contains(self, pattern: str, *, literal: bool = False) -> ExprT: default_match: [[true,false,true]] case_insensitive_match: [[true,false,true]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.contains( - pattern, literal=literal + return self._expr._append_node( + ExprNode( + ExprKind.ELEMENTWISE, "str.contains", pattern=pattern, literal=literal ) ) @@ -247,10 +241,8 @@ def slice(self, offset: int, length: int | None = None) -> ExprT: |2 papaya ya| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice( - offset=offset, length=length - ) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=offset, length=length) ) def split(self, by: str) -> ExprT: @@ -279,9 +271,7 @@ def split(self, by: str) -> ExprT: |└─────────┴────────────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.split(by=by) - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.split", by=by)) def head(self, n: int = 5) -> ExprT: r"""Take the first n elements of each string. @@ -305,8 +295,8 @@ def head(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_head: [["taata","taata","zukky"]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=0, length=n) ) def tail(self, n: int = 5) -> ExprT: @@ -331,10 +321,8 @@ def tail(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_tail: [["taata","atata","kkyun"]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice( - offset=-n, length=None - ) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=-n, length=None) ) def to_datetime(self, format: str | None = None) -> ExprT: @@ -375,8 +363,8 @@ def to_datetime(self, format: str | None = None) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.to_datetime", format=format) ) def to_date(self, format: str | None = None) -> ExprT: @@ -404,8 +392,8 @@ def to_date(self, format: str | None = None) -> ExprT: |a: [[2020-01-01,2020-01-02]]| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_date(format=format) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.to_date", format=format) ) def to_uppercase(self) -> ExprT: @@ -430,9 +418,7 @@ def to_uppercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_uppercase")) def to_lowercase(self) -> ExprT: r"""Transform string to lowercase variant. @@ -451,9 +437,7 @@ def to_lowercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_lowercase")) def to_titlecase(self) -> ExprT: """Modify strings to their titlecase equivalent. @@ -507,9 +491,7 @@ def to_titlecase(self) -> ExprT: |└─────────────────────────┴─────────────────────────┘| └─────────────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_titlecase() - ) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_titlecase")) def zfill(self, width: int) -> ExprT: """Transform string to zero-padded variant. @@ -535,6 +517,6 @@ def zfill(self, width: int) -> ExprT: |3 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.zfill(width) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "str.zfill", width=width) ) diff --git a/narwhals/expr_struct.py b/narwhals/expr_struct.py index fe74cf9f75..7d734732f9 100644 --- a/narwhals/expr_struct.py +++ b/narwhals/expr_struct.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -40,6 +42,6 @@ def field(self, name: str) -> ExprT: |└──────────────┴──────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).struct.field(name) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "struct.field", name=name) ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 6baef07688..28ab30a641 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -3,22 +3,14 @@ import platform import sys from collections.abc import Iterable, Mapping, Sequence -from functools import partial -from typing import TYPE_CHECKING, Any, Callable - -from narwhals._expression_parsing import ( - ExprKind, - ExprMetadata, - apply_n_ary_operation, - combine_metadata, - is_scalar_like, -) +from typing import TYPE_CHECKING, Any + +from narwhals._expression_parsing import ExprKind, ExprNode, is_expr, is_series from narwhals._utils import ( Implementation, Version, deprecate_native_namespace, flatten, - is_compliant_expr, is_eager_allowed, is_sequence_but_not_str, normalize_path, @@ -33,7 +25,6 @@ ) from narwhals.exceptions import InvalidOperationError from narwhals.expr import Expr -from narwhals.series import Series from narwhals.translate import from_native, to_native if TYPE_CHECKING: @@ -41,11 +32,11 @@ from typing_extensions import TypeAlias, TypeIs - from narwhals._compliant import CompliantExpr, CompliantNamespace from narwhals._native import NativeDataFrame, NativeLazyFrame, NativeSeries from narwhals._translate import IntoArrowTable from narwhals._typing import Backend, EagerAllowed, IntoBackend from narwhals.dataframe import DataFrame, LazyFrame + from narwhals.series import Series from narwhals.typing import ( ConcatMethod, FileSource, @@ -54,7 +45,6 @@ IntoExpr, IntoSchema, NonNestedLiteral, - _1DArray, _2DArray, ) @@ -957,16 +947,7 @@ def col(*names: str | Iterable[str]) -> Expr: └──────────────────┘ """ flat_names = flatten(names) - - def func(plx: Any) -> Any: - return plx.col(*flat_names) - - return Expr( - func, - ExprMetadata.selector_single() - if len(flat_names) == 1 - else ExprMetadata.selector_multi_named(), - ) + return Expr(ExprNode(ExprKind.COL, "col", names=flat_names)) def exclude(*names: str | Iterable[str]) -> Expr: @@ -995,12 +976,7 @@ def exclude(*names: str | Iterable[str]) -> Expr: | └─────┘ | └──────────────────┘ """ - exclude_names = frozenset(flatten(names)) - - def func(plx: Any) -> Any: - return plx.exclude(exclude_names) - - return Expr(func, ExprMetadata.selector_multi_unnamed()) + return Expr(ExprNode(ExprKind.EXCLUDE, "exclude", names=frozenset(flatten(names)))) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1031,16 +1007,7 @@ def nth(*indices: int | Sequence[int]) -> Expr: └──────────────────┘ """ flat_indices = flatten(indices) - - def func(plx: Any) -> Any: - return plx.nth(*flat_indices) - - return Expr( - func, - ExprMetadata.selector_single() - if len(flat_indices) == 1 - else ExprMetadata.selector_multi_unnamed(), - ) + return Expr(ExprNode(ExprKind.NTH, "nth", indices=flat_indices)) # Add underscore so it doesn't conflict with builtin `all` @@ -1061,7 +1028,7 @@ def all_() -> Expr: | 1 4 0.246 | └──────────────────┘ """ - return Expr(lambda plx: plx.all(), ExprMetadata.selector_multi_unnamed()) + return Expr(ExprNode(ExprKind.ALL, "all")) # Add underscore so it doesn't conflict with builtin `len` @@ -1087,11 +1054,7 @@ def len_() -> Expr: | └─────┘ | └──────────────────┘ """ - - def func(plx: Any) -> Any: - return plx.len() - - return Expr(func, ExprMetadata.aggregation()) + return Expr(ExprNode(ExprKind.AGGREGATION, "len")) def sum(*columns: str) -> Expr: @@ -1235,21 +1198,12 @@ def max(*columns: str) -> Expr: return col(*columns).max() -def _expr_with_n_ary_op( - func_name: str, - operation_factory: Callable[ - [CompliantNamespace[Any, Any]], Callable[..., CompliantExpr[Any, Any]] - ], - *exprs: IntoExpr, -) -> Expr: +def _expr_with_horizontal_op(name: str, *exprs: IntoExpr, **kwargs: Any) -> Expr: if not exprs: - msg = f"At least one expression must be passed to `{func_name}`" + msg = f"At least one expression must be passed to `{name}`" raise ValueError(msg) return Expr( - lambda plx: apply_n_ary_operation( - plx, operation_factory(plx), *exprs, str_as_lit=False - ), - ExprMetadata.from_horizontal_op(*exprs), + ExprNode(ExprKind.ELEMENTWISE, name, *exprs, **kwargs, allow_multi_output=True) ) @@ -1284,9 +1238,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: |└─────┴──────┴─────┘| └────────────────────┘ """ - return _expr_with_n_ary_op( - "sum_horizontal", lambda plx: plx.sum_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("sum_horizontal", *flatten(exprs)) def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1318,9 +1270,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: | h_min: [[1,5,3]] | └──────────────────┘ """ - return _expr_with_n_ary_op( - "min_horizontal", lambda plx: plx.min_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("min_horizontal", *flatten(exprs)) def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1354,73 +1304,29 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: |└─────┴──────┴───────┘| └──────────────────────┘ """ - return _expr_with_n_ary_op( - "max_horizontal", lambda plx: plx.max_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("max_horizontal", *flatten(exprs)) class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicate = all_horizontal(*flatten(predicates), ignore_nulls=False) - def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: - kind = ExprKind.from_into_expr(value, str_as_lit=False) - if self._predicate._metadata.is_scalar_like and not kind.is_scalar_like: - msg = ( - "If you pass a scalar-like predicate to `nw.when`, then " - "the `then` value must also be scalar-like." - ) - raise InvalidOperationError(msg) - + def then(self, value: IntoExpr | NonNestedLiteral) -> Then: return Then( - lambda plx: apply_n_ary_operation( - plx, - lambda *args: plx.when(args[0]).then(args[1]), + ExprNode( + ExprKind.ELEMENTWISE, + "when_then", self._predicate, value, - str_as_lit=False, - ), - combine_metadata( - self._predicate, - value, - str_as_lit=False, allow_multi_output=False, - to_single_output=False, - ), + ) ) class Then(Expr): - def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: - kind = ExprKind.from_into_expr(value, str_as_lit=False) - if self._metadata.is_scalar_like and not is_scalar_like(kind): - msg = ( - "If you pass a scalar-like predicate to `nw.when`, then " - "the `otherwise` value must also be scalar-like." - ) - raise InvalidOperationError(msg) - - def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: - compliant_expr = self._to_compliant_expr(plx) - compliant_value = plx.parse_into_expr(value, str_as_lit=False) - if ( - not self._metadata.is_scalar_like - and is_scalar_like(kind) - and is_compliant_expr(compliant_value) - ): - compliant_value = compliant_value.broadcast(kind) - return compliant_expr.otherwise(compliant_value) # type: ignore[attr-defined, no-any-return] - - return Expr( - func, - combine_metadata( - self, - value, - str_as_lit=False, - allow_multi_output=False, - to_single_output=False, - ), - ) + def otherwise(self, value: IntoExpr | NonNestedLiteral) -> Expr: + node = self._nodes[0] + return Expr(ExprNode(ExprKind.ELEMENTWISE, "when_then", *node.exprs, value)) def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: @@ -1503,10 +1409,8 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> |all: [[false,false,true,null,false,null]]| └─────────────────────────────────────────┘ """ - return _expr_with_n_ary_op( - "all_horizontal", - lambda plx: partial(plx.all_horizontal, ignore_nulls=ignore_nulls), - *flatten(exprs), + return _expr_with_horizontal_op( + "all_horizontal", *flatten(exprs), ignore_nulls=ignore_nulls ) @@ -1543,7 +1447,7 @@ def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: msg = f"Nested datatypes are not supported yet. Got {value}" raise NotImplementedError(msg) - return Expr(lambda plx: plx.lit(value, dtype), ExprMetadata.literal()) + return Expr(ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype)) def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr: @@ -1589,10 +1493,8 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> |└───────┴───────┴───────┘| └─────────────────────────┘ """ - return _expr_with_n_ary_op( - "any_horizontal", - lambda plx: partial(plx.any_horizontal, ignore_nulls=ignore_nulls), - *flatten(exprs), + return _expr_with_horizontal_op( + "any_horizontal", *flatten(exprs), ignore_nulls=ignore_nulls ) @@ -1623,9 +1525,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: | a: [[2.5,6.5,3]] | └──────────────────┘ """ - return _expr_with_n_ary_op( - "mean_horizontal", lambda plx: plx.mean_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("mean_horizontal", *flatten(exprs)) def concat_str( @@ -1674,12 +1574,8 @@ def concat_str( └──────────────────┘ """ flat_exprs = flatten([*flatten([exprs]), *more_exprs]) - return _expr_with_n_ary_op( - "concat_str", - lambda plx: lambda *args: plx.concat_str( - *args, separator=separator, ignore_nulls=ignore_nulls - ), - *flat_exprs, + return _expr_with_horizontal_op( + "concat_str", *flat_exprs, separator=separator, ignore_nulls=ignore_nulls ) @@ -1729,21 +1625,20 @@ def coalesce( """ flat_exprs = flatten([*flatten([exprs]), *more_exprs]) - non_exprs = [expr for expr in flat_exprs if not isinstance(expr, (str, Expr, Series))] + non_exprs = [ + expr + for expr in flat_exprs + if not (isinstance(expr, str) or is_expr(expr) or is_series(expr)) + ] if non_exprs: msg = ( - f"All arguments to `coalesce` must be of type {str!r}, {Expr!r}, or {Series!r}." + f"All arguments to `coalesce` must be of type {str!r}, Expr, or Series." "\nGot the following invalid arguments (type, value):" f"\n {', '.join(repr((type(e), e)) for e in non_exprs)}" ) raise TypeError(msg) - return Expr( - lambda plx: apply_n_ary_operation( - plx, lambda *args: plx.coalesce(*args), *flat_exprs, str_as_lit=False - ), - ExprMetadata.from_horizontal_op(*flat_exprs), - ) + return Expr(ExprNode(ExprKind.ELEMENTWISE, "coalesce", *flat_exprs)) def format(f_string: str, *args: IntoExpr) -> Expr: diff --git a/narwhals/group_by.py b/narwhals/group_by.py index c469ac921e..65b9a9b48b 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar -from narwhals._expression_parsing import all_exprs_are_scalar_like -from narwhals._utils import flatten, tupleify +from narwhals._expression_parsing import is_scalar_like +from narwhals._utils import tupleify from narwhals.exceptions import InvalidOperationError from narwhals.typing import DataFrameT @@ -72,8 +72,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: 2 b 3 2 3 c 3 1 """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(is_scalar_like(x) for x in compliant_aggs): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -81,14 +81,6 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: @@ -166,8 +158,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: |└─────┴─────┴─────┘| └───────────────────┘ """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(is_scalar_like(x) for x in compliant_aggs): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -175,12 +167,4 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 7eeeb356c9..bd34a00199 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, NoReturn -from narwhals._expression_parsing import ExprMetadata, combine_metadata +from narwhals._expression_parsing import ExprKind, ExprNode from narwhals._utils import flatten from narwhals.expr import Expr @@ -16,41 +16,45 @@ class Selector(Expr): def _to_expr(self) -> Expr: - return Expr(self._to_compliant_expr, self._metadata) + return Expr(*self._nodes) def __add__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" raise TypeError(msg) - return self._to_expr() + other # type: ignore[no-any-return] + return self._to_expr()._append_node( + ExprNode(ExprKind.ELEMENTWISE, "__add__", other, str_as_lit=True) + ) def __or__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self.__class__( - lambda plx: self._to_compliant_expr(plx) | other._to_compliant_expr(plx), - combine_metadata( - self, + return self._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "__or__", other, - str_as_lit=False, + str_as_lit=True, allow_multi_output=True, - to_single_output=False, - ), + ) ) - return self._to_expr() | other # type: ignore[no-any-return] + return self._to_expr()._append_node( + ExprNode(ExprKind.ELEMENTWISE, "__or__", other, str_as_lit=True) + ) def __and__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self.__class__( - lambda plx: self._to_compliant_expr(plx) & other._to_compliant_expr(plx), - combine_metadata( - self, + return self._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "__and__", other, - str_as_lit=False, + str_as_lit=True, allow_multi_output=True, - to_single_output=False, - ), + ) ) - return self._to_expr() & other # type: ignore[no-any-return] + return self._to_expr()._append_node( + ExprNode(ExprKind.ELEMENTWISE, "__and__", other, str_as_lit=True) + ) def __rsub__(self, other: Any) -> NoReturn: raise NotImplementedError @@ -86,10 +90,7 @@ def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Se c: [[8.2,4.6]] """ flattened = flatten(dtypes) - return Selector( - lambda plx: plx.selectors.by_dtype(flattened), - ExprMetadata.selector_multi_unnamed(), - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.by_dtype", dtypes=flattened)) def matches(pattern: str) -> Selector: @@ -114,9 +115,7 @@ def matches(pattern: str) -> Selector: 0 123 2.0 1 456 5.5 """ - return Selector( - lambda plx: plx.selectors.matches(pattern), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.matches", pattern=pattern)) def numeric() -> Selector: @@ -142,9 +141,7 @@ def numeric() -> Selector: │ 4 ┆ 4.6 │ └─────┴─────┘ """ - return Selector( - lambda plx: plx.selectors.numeric(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.numeric")) def boolean() -> Selector: @@ -174,9 +171,7 @@ def boolean() -> Selector: | └───────┘ | └──────────────────┘ """ - return Selector( - lambda plx: plx.selectors.boolean(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.boolean")) def string() -> Selector: @@ -202,9 +197,7 @@ def string() -> Selector: │ y │ └─────┘ """ - return Selector( - lambda plx: plx.selectors.string(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.string")) def categorical() -> Selector: @@ -232,9 +225,7 @@ def categorical() -> Selector: │ y │ └─────┘ """ - return Selector( - lambda plx: plx.selectors.categorical(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.categorical")) def all() -> Selector: @@ -254,9 +245,7 @@ def all() -> Selector: 0 1 x False 1 2 y True """ - return Selector( - lambda plx: plx.selectors.all(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.all")) def datetime( @@ -312,8 +301,12 @@ def datetime( tstamp_utc: [[2023-04-10 12:14:16.999000Z,2025-08-25 14:18:22.666000Z]] """ return Selector( - lambda plx: plx.selectors.datetime(time_unit=time_unit, time_zone=time_zone), - ExprMetadata.selector_multi_unnamed(), + ExprNode( + ExprKind.SELECTOR, + "selectors.datetime", + time_unit=time_unit, + time_zone=time_zone, + ) ) diff --git a/narwhals/series.py b/narwhals/series.py index 0eea411e28..2e34b49913 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2,8 +2,18 @@ import math from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Generic, + Literal, + cast, + overload, +) +from narwhals._expression_parsing import ExprKind, ExprNode from narwhals._utils import ( Implementation, Version, @@ -20,6 +30,8 @@ from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar from narwhals.dtypes import _validate_dtype, _validate_into_dtype from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.expr import Expr +from narwhals.functions import col from narwhals.series_cat import SeriesCatNamespace from narwhals.series_dt import SeriesDateTimeNamespace from narwhals.series_list import SeriesListNamespace @@ -90,6 +102,11 @@ def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame + def _to_expr(self) -> Expr: + return Expr( + ExprNode(ExprKind.SERIES, "_expr._from_series", series=self._compliant) + ) + def __init__( self, series: Any, *, level: Literal["full", "lazy", "interchange"] ) -> None: @@ -881,6 +898,18 @@ def clip( 5 3 dtype: int64 """ + if lower_bound is None: + return self._with_compliant( + self._compliant_series.clip_upper( + upper_bound=self._extract_native(upper_bound) + ) + ) + if upper_bound is None: + return self._with_compliant( + self._compliant_series.clip_lower( + lower_bound=self._extract_native(lower_bound) + ) + ) return self._with_compliant( self._compliant_series.clip( lower_bound=self._extract_native(lower_bound), @@ -2769,23 +2798,12 @@ def is_close( "Hint: `is_close` is only supported for numeric types" ) raise InvalidOperationError(msg) - - if abs_tol < 0: - msg = f"`abs_tol` must be non-negative but got {abs_tol}" - raise ComputeError(msg) - - if not (0 <= rel_tol < 1): - msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" - raise ComputeError(msg) - - return self._with_compliant( - self._compliant_series.is_close( - self._extract_native(other), - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, + ret_df = self.to_frame().select( + col(self.name).is_close( + other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal ) ) + return cast("Self", ret_df[self.name]) @property def str(self) -> SeriesStringNamespace[Self]: diff --git a/narwhals/series_str.py b/narwhals/series_str.py index ae98d4db34..d4c32dc97e 100644 --- a/narwhals/series_str.py +++ b/narwhals/series_str.py @@ -57,7 +57,7 @@ def replace( """ return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.str.replace( - pattern, self._extract_compliant(value), literal=literal, n=n + self._extract_compliant(value), pattern=pattern, literal=literal, n=n ) ) @@ -83,7 +83,7 @@ def replace_all( """ return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.str.replace_all( - pattern, self._extract_compliant(value), literal=literal + self._extract_compliant(value), pattern, literal=literal ) ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index e97e3dc7d7..42d06ce928 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -6,6 +6,7 @@ import narwhals as nw from narwhals import exceptions, functions as nw_f from narwhals._exceptions import issue_warning +from narwhals._expression_parsing import ExprKind, ExprNode, is_expr from narwhals._typing_compat import TypeVar, assert_never from narwhals._utils import ( Implementation, @@ -67,6 +68,7 @@ from typing_extensions import ParamSpec, Self + from narwhals._expression_parsing import ExprMetadata from narwhals._translate import IntoArrowTable from narwhals._typing import ( Arrow, @@ -246,18 +248,9 @@ def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame - def _extract_compliant(self, arg: Any) -> Any: - # After v1, we raise when passing order-dependent, length-changing, - # or filtration expressions to LazyFrame - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Mixing Series with LazyFrame is not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) + def _validate_metadata(self, metadata: ExprMetadata) -> None: + # After v1, we raise for order-dependent operations. + pass def collect( self, backend: IntoBackend[Polars | Pandas | Arrow] | None = None, **kwargs: Any @@ -380,15 +373,11 @@ def _l1_norm(self) -> Self: def head(self, n: int = 10) -> Self: r"""Get the first `n` rows.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).head(n) # type: ignore[attr-defined] - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "head", n=n)) def tail(self, n: int = 10) -> Self: r"""Get the last `n` rows.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).tail(n) # type: ignore[attr-defined] - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "tail", n=n)) def gather_every(self, n: int, offset: int = 0) -> Self: r"""Take every nth value in the Series and return as new Series. @@ -397,8 +386,8 @@ def gather_every(self, n: int, offset: int = 0) -> Self: n: Gather every *n*-th row. offset: Starting index. """ - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset) # type: ignore[attr-defined] + return self._append_node( + ExprNode(ExprKind.ORDERABLE_FILTRATION, "gather_every", n=n, offset=offset) ) def unique(self, *, maintain_order: bool | None = None) -> Self: @@ -409,33 +398,27 @@ def unique(self, *, maintain_order: bool | None = None) -> Self: "You can safely remove this argument." ) issue_warning(msg, UserWarning) - return self._with_filtration(lambda plx: self._to_compliant_expr(plx).unique()) + return self._append_node(ExprNode(ExprKind.FILTRATION, "unique")) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: """Sort this column. Place null values first.""" - return self._with_window( - lambda plx: self._to_compliant_expr(plx).sort( # type: ignore[attr-defined] - descending=descending, nulls_last=nulls_last + return self._append_node( + ExprNode( + ExprKind.WINDOW, "sort", descending=descending, nulls_last=nulls_last ) ) def arg_max(self) -> Self: """Returns the index of the maximum value.""" - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).arg_max() # type: ignore[attr-defined] - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_max")) def arg_min(self) -> Self: """Returns the index of the minimum value.""" - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).arg_min() # type: ignore[attr-defined] - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_min")) def arg_true(self) -> Self: """Find elements where boolean expression is True.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).arg_true() # type: ignore[attr-defined] - ) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "arg_true")) def sample( self, @@ -454,9 +437,14 @@ def sample( seed: Seed for the random number generator. If set to None (default), a random seed is generated for each sample operation. """ - return self._with_filtration( - lambda plx: self._to_compliant_expr(plx).sample( # type: ignore[attr-defined] - n, fraction=fraction, with_replacement=with_replacement, seed=seed + return self._append_node( + ExprNode( + ExprKind.FILTRATION, + "sample", + n=n, + fraction=fraction, + with_replacement=with_replacement, + seed=seed, ) ) @@ -494,7 +482,7 @@ def _stableify( if isinstance(obj, NwSeries): return Series(obj._compliant_series._with_version(Version.V1), level=obj._level) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, obj._metadata) + return Expr(*obj._nodes) assert_never(obj) @@ -1215,7 +1203,7 @@ def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: class Then(nw_f.Then, Expr): @classmethod def from_then(cls, then: nw_f.Then) -> Then: - return cls(then._to_compliant_expr, then._metadata) + return cls(*then._nodes) def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: return _stableify(super().otherwise(value)) @@ -1392,6 +1380,7 @@ def scan_parquet( "Int32", "Int64", "Int128", + "InvalidIntoExprError", "LazyFrame", "List", "Object", @@ -1426,6 +1415,7 @@ def scan_parquet( "generate_temporary_column_name", "get_level", "get_native_namespace", + "is_expr", "is_ordered_categorical", "len", "lit", diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py index d85b129b8a..fd342e9b92 100644 --- a/narwhals/stable/v2/__init__.py +++ b/narwhals/stable/v2/__init__.py @@ -340,7 +340,7 @@ def _stableify( if isinstance(obj, NwSeries): return Series(obj._compliant_series._with_version(Version.V2), level=obj._level) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, obj._metadata) + return Expr(*obj._nodes) assert_never(obj) @@ -953,7 +953,7 @@ def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: class Then(nw_f.Then, Expr): @classmethod def from_then(cls, then: nw_f.Then) -> Then: - return cls(then._to_compliant_expr, then._metadata) + return cls(*then._nodes) def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: return _stableify(super().otherwise(value)) diff --git a/narwhals/typing.py b/narwhals/typing.py index 2425d01ebe..8c89b7745b 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -98,7 +98,17 @@ def Time(self) -> type[dtypes.Time]: ... def Binary(self) -> type[dtypes.Binary]: ... -IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"] +_ShapeT = TypeVar("_ShapeT", bound="tuple[int, ...]") +_NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]" +_1DArray: TypeAlias = "_NDArray[tuple[int]]" +_1DArrayInt: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.integer[Any]]]" +_2DArray: TypeAlias = "_NDArray[tuple[int, int]]" # noqa: PYI047 +_AnyDArray: TypeAlias = "_NDArray[tuple[int, ...]]" # noqa: PYI047 +_NumpyScalar: TypeAlias = "np.generic[Any]" +Into1DArray: TypeAlias = "_1DArray | _NumpyScalar" +"""A 1-dimensional `numpy.ndarray` or scalar that can be converted into one.""" + +IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]", _1DArray] """Anything which can be converted to an expression. Use this to mean "either a Narwhals expression, or something which can be converted @@ -250,15 +260,6 @@ def Binary(self) -> type[dtypes.Binary]: ... - *"all"*: Keeps all the mode's. """ -_ShapeT = TypeVar("_ShapeT", bound="tuple[int, ...]") -_NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]" -_1DArray: TypeAlias = "_NDArray[tuple[int]]" -_1DArrayInt: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.integer[Any]]]" -_2DArray: TypeAlias = "_NDArray[tuple[int, int]]" # noqa: PYI047 -_AnyDArray: TypeAlias = "_NDArray[tuple[int, ...]]" # noqa: PYI047 -_NumpyScalar: TypeAlias = "np.generic[Any]" -Into1DArray: TypeAlias = "_1DArray | _NumpyScalar" -"""A 1-dimensional `numpy.ndarray` or scalar that can be converted into one.""" PandasLikeDType: TypeAlias = "pd.api.extensions.ExtensionDtype | np.dtype[Any]" diff --git a/tests/conftest.py b/tests/conftest.py index 650b8d4a3a..e4c3d26b88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -130,6 +130,7 @@ def cudf_constructor(obj: Data) -> IntoDataFrame: # pragma: no cover def polars_eager_constructor(obj: Data) -> pl.DataFrame: + pytest.importorskip("polars") import polars as pl return pl.DataFrame(obj) @@ -222,7 +223,7 @@ def ibis_lazy_constructor(obj: Data) -> ibis.Table: # pragma: no cover pytest.importorskip("polars") import polars as pl - ldf = pl.from_dict(obj).lazy() + ldf = pl.LazyFrame(obj) table_name = str(uuid.uuid4()) return _ibis_backend().create_table(table_name, ldf) diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index 139c862fff..7200ff06d5 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -8,6 +8,7 @@ import narwhals as nw from tests.utils import ( + DASK_VERSION, DUCKDB_VERSION, POLARS_VERSION, Constructor, @@ -34,6 +35,9 @@ def test_fill_null(constructor: Constructor) -> None: def test_fill_null_w_aggregate(constructor: Constructor) -> None: + if "dask" in str(constructor) and DASK_VERSION < (2024, 12): + # Bug in old version of Dask. + pytest.skip() if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): pytest.skip() data = {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]} diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 24f5a32457..b5457bf754 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -385,7 +385,7 @@ def test_over_cum_reverse( def test_over_raise_len_change(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): nw.from_native(df).select(nw.col("b").drop_nulls().over("a")) diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index faeadf8bcb..c54a0356c7 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -15,7 +15,7 @@ def test_unique_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) context = ( - pytest.raises(InvalidOperationError) + pytest.raises((InvalidOperationError, NotImplementedError)) if isinstance(df, nw.LazyFrame) else does_not_raise() ) @@ -41,10 +41,10 @@ def test_unique_expr_agg( def test_unique_illegal_combination(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): - df.select((nw.col("a").unique() + nw.col("b").unique()).sum()) - with pytest.raises(InvalidOperationError): - df.select(nw.col("a").unique() + nw.col("b")) + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select((nw.col("a").unique() + nw.col("a").unique()).sum()) + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select(nw.col("a").unique() + nw.col("a")) def test_unique_series(constructor_eager: ConstructorEager) -> None: diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 46b07c66d7..cbb908bf64 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -6,7 +6,7 @@ import pytest import narwhals as nw -from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError +from narwhals.exceptions import MultiOutputExpressionError from tests.utils import DUCKDB_VERSION, Constructor, ConstructorEager, assert_equal_data if TYPE_CHECKING: @@ -115,13 +115,16 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_then_invalid(constructor: Constructor) -> None: +def test_when_then_broadcasting(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): - df.select(nw.when(nw.col("a").sum() > 1).then("c")) - - with pytest.raises(InvalidOperationError): - df.select(nw.when(nw.col("a").sum() > 1).then(1).otherwise("c")) + result = df.select(nw.when(nw.col("a").sum() > 1).then("c")) + expected = {"c": [4.1, 5, 6]} + assert_equal_data(result, expected) + result = df.select(nw.when(nw.col("a").sum() > 1).then(1).otherwise("c")) + expected = {"literal": [1, 1, 1]} + assert_equal_data(result, expected) def test_when_then_otherwise_lit_str(constructor: Constructor) -> None: diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 79b7f89b70..138063eb06 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -4,96 +4,133 @@ import narwhals as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import DUCKDB_VERSION, POLARS_VERSION, Constructor, assert_equal_data @pytest.mark.parametrize( ("expr", "expected"), [ - (nw.col("a"), 0), - (nw.col("a").mean(), 0), - (nw.col("a").cum_sum(), 1), - (nw.col("a").cum_sum().over(order_by="id"), 0), - (nw.col("a").cum_sum().abs().over(order_by="id"), 1), - ((nw.col("a").cum_sum() + 1).over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum(), 2), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), 1), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), 0), + (nw.col("a"), [-1, 2, 3]), + (nw.col("a").mean(), [4 / 3, 4 / 3, 4 / 3]), + (nw.col("a").cum_sum().over(order_by="i"), [-1, 1, 4]), + (nw.col("a").alias("b").cum_sum().over(order_by="i"), [-1, 1, 4]), + (nw.col("a").cum_sum().abs().over(order_by="i"), [1, 1, 4]), + ((nw.col("a").cum_sum() + 1).over(order_by="i"), [0, 2, 5]), ( - nw.sum_horizontal( - nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i") - ), - 1, + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), + [-2, 3, 7], + ), + ( + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), + [-2, 3, 7], ), ( nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( order_by="i" ), - 2, + [-1.0, 4.0, 5.0], ), ( nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( order_by="i" ), - 2, + [-1.0, 4.0, 5.0], + ), + ( + (nw.col("a").sum() + nw.col("a").rolling_sum(2, min_samples=1)).over( + order_by="i" + ), + [3.0, 5.0, 9.0], + ), + ((nw.col("a").sum() + nw.col("a").mean()).over("b"), [1.5, 1.5, 6.0]), + ( + (nw.col("a").mean().abs() + nw.sum_horizontal(nw.col("a").diff())).over( + order_by="i" + ), + [4 / 3, 13 / 3, 7 / 3], ), ], ) -def test_window_kind(expr: nw.Expr, expected: int) -> None: - assert expr._metadata.n_orderable_ops == expected - - -def test_misleading_order_by() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over(order_by="b") - - -def test_double_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c") - - -def test_double_agg() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().mean() - with pytest.raises(InvalidOperationError): - nw.col("a").mean().sum() +def test_over_pushdown( + constructor: Constructor, expr: nw.Expr, expected: list[float] +) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + pytest.skip() + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() + data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} + df = nw.from_native(constructor(data)).lazy() + result = df.select("i", a=expr).sort("i").select("a") + assert_equal_data(result, {"a": expected}) -def test_filter_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().drop_nulls() - - -def test_rank_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().rank() - with pytest.raises(InvalidOperationError): - nw.col("a").mean().is_unique() - - -def test_diff_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().diff() +@pytest.mark.parametrize( + ("expr", "expected"), [((nw.col("a") - nw.col("a").mean()).over("b"), [-1.5, 1.5, 0])] +) +def test_per_group_broadcasting( + constructor: Constructor, + expr: nw.Expr, + expected: list[float], + request: pytest.FixtureRequest, +) -> None: + if "dask" in str(constructor): + # sigh... + request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() + data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} + df = nw.from_native(constructor(data)).lazy() + result = df.select("i", a=expr).sort("i").select("a") + assert_equal_data(result, {"a": expected}) -def test_invalid_over() -> None: +@pytest.mark.parametrize( + "expr", + [ + nw.col("a").cum_sum(), + nw.col("a").cum_sum().cum_sum().over(order_by="i"), + nw.col("a").cum_sum().cum_sum(), + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), + nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i")), + nw.col("a").mean().over(order_by="i"), + nw.col("a").mean().over("b").over("c"), + nw.col("a").mean().over("b").over("c", order_by="i"), + nw.col("a").mean().mean(), + nw.col("a").mean().sum(), + nw.col("a").mean().drop_nulls(), + nw.col("a").mean().rank(), + nw.col("a").mean().is_unique(), + nw.col("a").mean().diff(), + nw.col("a").drop_nulls().over("b"), + nw.col("a").drop_nulls().over("b", order_by="i"), + nw.col("a").diff().drop_nulls().over("b", order_by="i"), + nw.col("a").filter(nw.col("b").sum().over("c") > 1).sum().over("d"), + ], +) +def test_invalid_operations(constructor: Constructor, expr: nw.Expr) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + pytest.skip() + df = nw.from_native( + constructor({"a": [-1, 2, 3], "b": [1, 1, 1], "c": [2, 2, 2], "i": [0, 1, 2]}) + ).lazy() + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select(a=expr) + + +def test_invalid_elementwise_over() -> None: + # This one raises before it's even evaluated. with pytest.raises(InvalidOperationError): nw.col("a").fill_null(3).over("b") -def test_nested_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c") - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c", order_by="i") - +def test_rank_with_order_by_pushdown() -> None: + pytest.importorskip("pandas") + import pandas as pd -def test_filtration_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").drop_nulls().over("b") - with pytest.raises(InvalidOperationError): - nw.col("a").drop_nulls().over("b", order_by="i") - with pytest.raises(InvalidOperationError): - nw.col("a").diff().drop_nulls().over("b", order_by="i") + df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "i": [2, 1, 0]})) + result = df.select( + "a", + res=nw.sum_horizontal(nw.col("a").rank("ordinal"), nw.lit(1)).over(order_by="i"), + ) + expected = {"a": [1, 1, 2], "res": [3.0, 2.0, 4.0]} + assert_equal_data(result, expected) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index 1824dfa7b7..f8048d8e4c 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -81,7 +81,7 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None: def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): df.filter(nw.col("b").unique() > 2).lazy().collect() diff --git a/tests/frame/group_by_test.py b/tests/frame/group_by_test.py index f6b513af98..d665ad9292 100644 --- a/tests/frame/group_by_test.py +++ b/tests/frame/group_by_test.py @@ -371,7 +371,7 @@ def test_group_by_shift_raises(constructor: Constructor) -> None: df_native = {"a": [1, 2, 3], "b": [1, 1, 2]} df = nw.from_native(constructor(df_native)) with pytest.raises(InvalidOperationError, match="does not aggregate"): - df.group_by("b").agg(nw.col("a").shift(1)) + df.group_by("b").agg(nw.col("a").abs()) def test_double_same_aggregation( @@ -534,7 +534,7 @@ def test_group_by_raise_if_not_preserves_length( ) -> None: data = {"a": [1, 2, 2, None], "b": [0, 1, 2, 3], "x": [1, 2, 3, 4]} df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): df.group_by(keys).agg(nw.col("x").max()) diff --git a/tests/frame/lazy_test.py b/tests/frame/lazy_test.py index 8a16d48bad..0d5020049c 100644 --- a/tests/frame/lazy_test.py +++ b/tests/frame/lazy_test.py @@ -55,6 +55,7 @@ def test_lazy_to_default(constructor_eager: ConstructorEager) -> None: assert isinstance(result.to_native(), expected_cls) +@pytest.mark.slow @pytest.mark.parametrize( "backend", [ diff --git a/tests/repr_test.py b/tests/repr_test.py index 2ef9f77271..62100bee98 100644 --- a/tests/repr_test.py +++ b/tests/repr_test.py @@ -100,3 +100,19 @@ def test_polars_series_repr() -> None: "└────────────────────┘" ) assert result == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nw.col("a"), "col(a)"), + (nw.col("a").abs(), "col(a).abs()"), + (nw.col("a").std(ddof=2), "col(a).std(ddof=2)"), + ( + nw.sum_horizontal(nw.col("a").rolling_mean(2), "b"), + "sum_horizontal(col(a).rolling_mean(window_size=2, min_samples=2, center=False), b)", + ), + ], +) +def test_expr_repr(expr: nw.Expr, expected: str) -> None: + assert repr(expr) == expected diff --git a/tests/v1_test.py b/tests/v1_test.py index 67660f3614..a533f65b05 100644 --- a/tests/v1_test.py +++ b/tests/v1_test.py @@ -902,9 +902,12 @@ def test_unique_series_v1() -> None: series.to_frame().select(nw_v1.col("a").unique(maintain_order=False).sum()) -def test_head_aggregation() -> None: +def test_invalid() -> None: + df = nw.from_native(pd.DataFrame({"a": [1, 2]})) with pytest.raises(InvalidOperationError): - nw_v1.col("a").mean().head() + df.select(nw_v1.col("a").mean().head()) + with pytest.raises(InvalidOperationError): + df.select(nw_v1.col("a").mean().arg_true()) def test_deprecated_expr_methods() -> None: