diff --git a/narwhals/_arrow/series_cat.py b/narwhals/_arrow/series_cat.py index 944f3390e3..23a08bf483 100644 --- a/narwhals/_arrow/series_cat.py +++ b/narwhals/_arrow/series_cat.py @@ -5,13 +5,14 @@ import pyarrow as pa from narwhals._arrow.utils import ArrowSeriesNamespace +from narwhals._compliant.any_namespace import CatNamespace if TYPE_CHECKING: from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import Incomplete -class ArrowSeriesCatNamespace(ArrowSeriesNamespace): +class ArrowSeriesCatNamespace(ArrowSeriesNamespace, CatNamespace["ArrowSeries"]): def get_categories(self) -> ArrowSeries: # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes chunks: Incomplete = self.native.chunks diff --git a/narwhals/_arrow/series_dt.py b/narwhals/_arrow/series_dt.py index 170f6dd356..6f29767f14 100644 --- a/narwhals/_arrow/series_dt.py +++ b/narwhals/_arrow/series_dt.py @@ -6,6 +6,7 @@ import pyarrow.compute as pc from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit +from narwhals._compliant.any_namespace import DateTimeNamespace from narwhals._constants import ( MS_PER_MINUTE, MS_PER_SECOND, @@ -36,7 +37,9 @@ IntoRhs: TypeAlias = int -class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace): +class ArrowSeriesDateTimeNamespace( + ArrowSeriesNamespace, DateTimeNamespace["ArrowSeries"] +): _TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = { "ns": NS_PER_SECOND, "us": US_PER_SECOND, diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index 4ee663ba20..defad3dad6 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -6,19 +6,19 @@ import pyarrow.compute as pc from narwhals._arrow.utils import ArrowSeriesNamespace +from narwhals._compliant.any_namespace import ListNamespace from narwhals._utils import not_implemented if TYPE_CHECKING: from narwhals._arrow.series import ArrowSeries -class ArrowSeriesListNamespace(ArrowSeriesNamespace): +class ArrowSeriesListNamespace(ArrowSeriesNamespace, ListNamespace["ArrowSeries"]): def len(self) -> ArrowSeries: return self.with_native(pc.list_value_length(self.native).cast(pa.uint32())) - unique = not_implemented() - - contains = not_implemented() - def get(self, index: int) -> ArrowSeries: return self.with_native(pc.list_element(self.native, index)) + + unique = not_implemented() + contains = not_implemented() diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 1df8fc31a0..4b3fe0ee1d 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -7,13 +7,14 @@ import pyarrow.compute as pc from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format +from narwhals._compliant.any_namespace import StringNamespace if TYPE_CHECKING: from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import Incomplete -class ArrowSeriesStringNamespace(ArrowSeriesNamespace): +class ArrowSeriesStringNamespace(ArrowSeriesNamespace, StringNamespace["ArrowSeries"]): def len_chars(self) -> ArrowSeries: return self.with_native(pc.utf8_length(self.native)) diff --git a/narwhals/_arrow/series_struct.py b/narwhals/_arrow/series_struct.py index be5aa4b393..906725ba7b 100644 --- a/narwhals/_arrow/series_struct.py +++ b/narwhals/_arrow/series_struct.py @@ -5,11 +5,12 @@ import pyarrow.compute as pc from narwhals._arrow.utils import ArrowSeriesNamespace +from narwhals._compliant.any_namespace import StructNamespace if TYPE_CHECKING: from narwhals._arrow.series import ArrowSeries -class ArrowSeriesStructNamespace(ArrowSeriesNamespace): +class ArrowSeriesStructNamespace(ArrowSeriesNamespace, StructNamespace["ArrowSeries"]): def field(self, name: str) -> ArrowSeries: return self.with_native(pc.struct_field(self.native, name)).alias(name) diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index e5f6f0c869..54df3160fd 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -2,13 +2,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol from narwhals._utils import CompliantT_co, _StoresCompliant if TYPE_CHECKING: from typing import Callable + from narwhals._compliant.typing import Accessor from narwhals.typing import NonNestedLiteral, TimeUnit __all__ = [ @@ -16,16 +17,25 @@ "DateTimeNamespace", "ListNamespace", "NameNamespace", + "NamespaceAccessor", "StringNamespace", "StructNamespace", ] -class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): +class NamespaceAccessor(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] + + +class CatNamespace(NamespaceAccessor[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] = "cat" + def get_categories(self) -> CompliantT_co: ... class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] = "dt" + def to_string(self, format: str) -> CompliantT_co: ... def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ... def convert_time_zone(self, time_zone: str) -> CompliantT_co: ... @@ -52,15 +62,17 @@ def offset_by(self, by: str) -> CompliantT_co: ... class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): - def get(self, index: int) -> CompliantT_co: ... + _accessor: ClassVar[Accessor] = "list" + def get(self, index: int) -> CompliantT_co: ... def len(self) -> CompliantT_co: ... - def unique(self) -> CompliantT_co: ... def contains(self, item: NonNestedLiteral) -> CompliantT_co: ... class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] = "name" + def keep(self) -> CompliantT_co: ... def map(self, function: Callable[[str], str]) -> CompliantT_co: ... def prefix(self, prefix: str) -> CompliantT_co: ... @@ -70,6 +82,8 @@ def to_uppercase(self) -> CompliantT_co: ... class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] = "str" + def len_chars(self) -> CompliantT_co: ... def replace( self, pattern: str, value: str, *, literal: bool, n: int @@ -91,4 +105,6 @@ def zfill(self, width: int) -> CompliantT_co: ... class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): + _accessor: ClassVar[Accessor] = "struct" + def field(self, name: str) -> CompliantT_co: ... diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 6f7f45d548..8bdb1bc6ad 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -204,3 +204,8 @@ class ScalarKwargs(TypedDict, total=False): - https://github.com/narwhals-dev/narwhals/issues/2526 - https://github.com/narwhals-dev/narwhals/issues/2660 """ + +Accessor: TypeAlias = Literal[ + "arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct" +] +"""`{Expr,Series}` method namespace accessor name.""" diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 2c23a741ee..8de0ad96c4 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal import polars as pl @@ -22,6 +22,7 @@ from typing_extensions import Self + from narwhals._compliant.typing import Accessor from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._polars.dataframe import Method from narwhals._polars.namespace import PolarsNamespace @@ -400,17 +401,11 @@ class PolarsExprDateTimeNamespace( class PolarsExprStringNamespace( PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr] ): + @requires.backend_version((0, 20, 5)) def zfill(self, width: int) -> PolarsExpr: backend_version = self.compliant._backend_version native_result = self.native.str.zfill(width) - if backend_version < (0, 20, 5): # pragma: no cover - # Reason: - # `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer` - # in `native_expr.str.slice(1, length)` - msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'." - raise NotImplementedError(msg) - if backend_version <= (1, 30, 0): length = self.native.str.len_chars() less_than_width = length < width @@ -435,7 +430,7 @@ class PolarsExprCatNamespace( class PolarsExprNameNamespace(PolarsExprNamespace): - _accessor = "name" + _accessor: ClassVar[Accessor] = "name" keep: Method[PolarsExpr] map: Method[PolarsExpr] prefix: Method[PolarsExpr] diff --git a/narwhals/_polars/typing.py b/narwhals/_polars/typing.py index a7dd697d56..b38aab8b98 100644 --- a/narwhals/_polars/typing.py +++ b/narwhals/_polars/typing.py @@ -1,25 +1,10 @@ from __future__ import annotations # pragma: no cover -from typing import ( - TYPE_CHECKING, # pragma: no cover - Union, # pragma: no cover -) +from typing import TYPE_CHECKING # pragma: no cover if TYPE_CHECKING: - import sys - from typing import Literal, TypeVar - - if sys.version_info >= (3, 10): - from typing import TypeAlias - else: - from typing_extensions import TypeAlias + from typing import TypeVar from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame - from narwhals._polars.expr import PolarsExpr - from narwhals._polars.series import PolarsSeries - IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries] FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame) - NativeAccessor: TypeAlias = Literal[ - "arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct" - ] diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 8a365ddd00..279ff6ae85 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -30,10 +30,10 @@ from typing_extensions import TypeIs + from narwhals._compliant.typing import Accessor from narwhals._polars.dataframe import Method from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries - from narwhals._polars.typing import NativeAccessor from narwhals.dtypes import DType from narwhals.typing import IntoDType @@ -261,7 +261,7 @@ class PolarsAnyNamespace( _StoresNative[NativeT_co], Protocol[CompliantT_co, NativeT_co], ): - _accessor: ClassVar[NativeAccessor] + _accessor: ClassVar[Accessor] def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]: def func(*args: Any, **kwargs: Any) -> CompliantT_co: @@ -273,7 +273,7 @@ def func(*args: Any, **kwargs: Any) -> CompliantT_co: class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): - _accessor: ClassVar[NativeAccessor] = "dt" + _accessor: ClassVar[Accessor] = "dt" def truncate(self, every: str) -> CompliantT: # Ensure consistent error message is raised. @@ -309,7 +309,7 @@ def offset_by(self, by: str) -> CompliantT: class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): - _accessor: ClassVar[NativeAccessor] = "str" + _accessor: ClassVar[Accessor] = "str" # NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage @abc.abstractmethod @@ -331,12 +331,12 @@ def zfill(self, width: int) -> CompliantT: ... class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): - _accessor: ClassVar[NativeAccessor] = "cat" + _accessor: ClassVar[Accessor] = "cat" get_categories: Method[CompliantT] class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): - _accessor: ClassVar[NativeAccessor] = "list" + _accessor: ClassVar[Accessor] = "list" @abc.abstractmethod def len(self) -> CompliantT: ... @@ -347,5 +347,5 @@ def len(self) -> CompliantT: ... class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): - _accessor: ClassVar[NativeAccessor] = "struct" + _accessor: ClassVar[Accessor] = "struct" field: Method[CompliantT] diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 8f0ed9b4f8..33375c9a2f 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -15,6 +15,7 @@ TYPE_CHECKING, Any, Callable, + Final, Generic, Literal, Protocol, @@ -72,7 +73,13 @@ CompliantSeriesT, NativeSeriesT_co, ) - from narwhals._compliant.typing import EvalNames, NativeDataFrameT, NativeLazyFrameT + from narwhals._compliant.any_namespace import NamespaceAccessor + from narwhals._compliant.typing import ( + Accessor, + EvalNames, + NativeDataFrameT, + NativeLazyFrameT, + ) from narwhals._namespace import ( Namespace, _NativeArrow, @@ -163,8 +170,9 @@ def columns(self) -> Sequence[str]: ... _T = TypeVar("_T") NativeT_co = TypeVar("NativeT_co", covariant=True) CompliantT_co = TypeVar("CompliantT_co", covariant=True) -_ContextT = TypeVar("_ContextT", bound="_FullContext") -_Method: TypeAlias = "Callable[Concatenate[_ContextT, P], R]" +_IntoContext: TypeAlias = "_FullContext | NamespaceAccessor[_FullContext]" +_IntoContextT = TypeVar("_IntoContextT", bound=_IntoContext) +_Method: TypeAlias = "Callable[Concatenate[_IntoContextT, P], R]" _Constructor: TypeAlias = "Callable[Concatenate[_T, P], R2]" @@ -223,12 +231,11 @@ class _LimitedContext(_StoresImplementation, _StoresVersion, Protocol): """ -class _FullContext(_StoresBackendVersion, _LimitedContext, Protocol): - """Provides 3 attributes. +class _FullContext(_StoresImplementation, _StoresBackendVersion, Protocol): + """Provides 2 attributes. - `_implementation` - `_backend_version` - - `_version` """ @@ -1631,9 +1638,11 @@ def fn(_frame: Any, /) -> Sequence[str]: return fn +_SENTINEL: Final = object() + + def _hasattr_static(obj: Any, attr: str) -> bool: - sentinel = object() - return getattr_static(obj, attr, sentinel) is not sentinel + return getattr_static(obj, attr, _SENTINEL) is not _SENTINEL def is_compliant_dataframe( @@ -1671,6 +1680,16 @@ def is_compliant_expr( return hasattr(obj, "__narwhals_expr__") +def _is_namespace_accessor(obj: _IntoContext) -> TypeIs[NamespaceAccessor[_FullContext]]: + # NOTE: Only `compliant` has false positives **internally** + # - https://github.com/narwhals-dev/narwhals/blob/cc69bac35eb8c81a1106969c49bfba9fd569b856/narwhals/_compliant/group_by.py#L44-L49 + # - https://github.com/narwhals-dev/narwhals/blob/cc69bac35eb8c81a1106969c49bfba9fd569b856/narwhals/_namespace.py#L166-L168 + # NOTE: Only `_accessor` has false positives **upstream** + # - https://github.com/pandas-dev/pandas/blob/e209a35403f8835bbcff97636b83d2fc39b51e68/pandas/core/accessor.py#L200-L233 + # - https://github.com/pola-rs/polars/blob/a60c5019f7b694c97009ef9208d25aaa4cc1d8a6/py-polars/polars/api.py#L29-L42 + return _hasattr_static(obj, "compliant") and _hasattr_static(obj, "_accessor") + + def is_eager_allowed(impl: Implementation, /) -> TypeIs[_EagerAllowedImpl]: """Return True if `impl` allows eager operations.""" return impl in { @@ -1906,6 +1925,11 @@ class requires: # noqa: N801 _min_version: tuple[int, ...] _hint: str + _wrapped_name: str + """(Unqualified) decorated method name. + + When used in a namespace accessor, it will be prefixed by the property name. + """ @classmethod def backend_version(cls, minimum: tuple[int, ...], /, hint: str = "") -> Self: @@ -1924,23 +1948,37 @@ def backend_version(cls, minimum: tuple[int, ...], /, hint: str = "") -> Self: def _unparse_version(backend_version: tuple[int, ...], /) -> str: return ".".join(f"{d}" for d in backend_version) - def _ensure_version(self, instance: _FullContext, /) -> None: - if instance._backend_version >= self._min_version: + def _qualify_accessor_name(self, prefix: Accessor, /) -> None: + # NOTE: Should only need to do this once per class (the first time the method is called) + if "." not in self._wrapped_name: + self._wrapped_name = f"{prefix}.{self._wrapped_name}" + + def _unwrap_context(self, instance: _IntoContext, /) -> tuple[tuple[int, ...], str]: + if _is_namespace_accessor(instance): + self._qualify_accessor_name(instance._accessor) + compliant = instance.compliant + else: + compliant = instance + return compliant._backend_version, str(compliant._implementation) + + def _ensure_version(self, instance: _IntoContext, /) -> None: + version, backend = self._unwrap_context(instance) + if version >= self._min_version: return - method = self._wrapped_name - backend = instance._implementation minimum = self._unparse_version(self._min_version) - found = self._unparse_version(instance._backend_version) - msg = f"`{method}` is only available in '{backend}>={minimum}', found version {found!r}." + found = self._unparse_version(version) + msg = f"`{self._wrapped_name}` is only available in '{backend}>={minimum}', found version {found!r}." if self._hint: msg = f"{msg}\n{self._hint}" raise NotImplementedError(msg) - def __call__(self, fn: _Method[_ContextT, P, R], /) -> _Method[_ContextT, P, R]: + def __call__( + self, fn: _Method[_IntoContextT, P, R], / + ) -> _Method[_IntoContextT, P, R]: self._wrapped_name = fn.__name__ @wraps(fn) - def wrapper(instance: _ContextT, *args: P.args, **kwds: P.kwargs) -> R: + def wrapper(instance: _IntoContextT, *args: P.args, **kwds: P.kwargs) -> R: self._ensure_version(instance) return fn(instance, *args, **kwds) diff --git a/tests/utils_test.py b/tests/utils_test.py index c38d5e8475..9d7a0247ba 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -4,7 +4,7 @@ import string from dataclasses import dataclass from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Protocol, cast +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, cast import hypothesis.strategies as st import pandas as pd @@ -16,7 +16,6 @@ import narwhals as nw from narwhals._utils import ( Implementation, - Version, _DeferredIterable, check_columns_exist, deprecate_native_namespace, @@ -31,6 +30,7 @@ from typing_extensions import Self + from narwhals._compliant.typing import Accessor from narwhals._utils import _SupportsVersion from narwhals.series import Series @@ -495,9 +495,25 @@ def func3( def test_requires() -> None: + class SomeAccesssor: + _accessor: ClassVar[Accessor] = "str" + + def __init__(self, compliant: ProbablyCompliant) -> None: + self._compliant: ProbablyCompliant = compliant + + @property + def compliant(self) -> ProbablyCompliant: + return self._compliant + + def waddle(self) -> str: + return f"waddle<{self.compliant.native}>waddle" + + @requires.backend_version((1, 8, 0)) + def nope(self) -> str: + return "nooooooooooooooooooooooooooo" + class ProbablyCompliant: _implementation: Implementation = Implementation.POLARS - _version: Version = Version.MAIN def __init__(self, native_obj: str, backend_version: tuple[int, ...]) -> None: self._native_obj: str = native_obj @@ -519,6 +535,10 @@ def concat(self, *strings: str, separator: str = "") -> str: def repeat(self, n: int) -> str: return self.native * n + @property + def str(self) -> SomeAccesssor: + return SomeAccesssor(self) + v_05 = ProbablyCompliant("123", (0, 5)) v_201 = ProbablyCompliant("123", (2, 0, 1)) v_300 = ProbablyCompliant("123", (3, 0, 0)) @@ -547,6 +567,15 @@ def repeat(self, n: int) -> str: with pytest.raises(NotImplementedError, match=pattern): v_05.concat("never") + waddled = v_201.str.waddle() + assert waddled == "waddle<123>waddle" + assert v_05.str.waddle() == waddled + noped = v_201.str.nope() + assert noped == "nooooooooooooooooooooooooooo" + match = r"`str\.nope`.+\'polars>=1.8.0\'.+found.+\'0.5\'" + with pytest.raises(NotImplementedError, match=match): + v_05.str.nope() + def test_deferred_iterable() -> None: def to_upper(it: Iterable[str]) -> Callable[[], Iterator[str]]: