diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 509239da55..929c81a220 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -73,11 +73,7 @@ CompliantSeriesT, NativeSeriesT_co, ) - from narwhals._compliant.typing import ( - EvalNames, - NativeDataFrameT, - NativeLazyFrameT, - ) + from narwhals._compliant.typing import EvalNames, NativeDataFrameT, NativeLazyFrameT from narwhals._namespace import ( Namespace, _NativeArrow, @@ -2067,6 +2063,7 @@ def deep_getattr(obj: Any, name_1: str, *nested: str) -> Any: """Perform a nested attribute lookup on `obj`.""" return deep_attrgetter(name_1, *nested)(obj) + class Compliant( _StoresNative[NativeT_co], _StoresImplementation, Protocol[NativeT_co] ): ... diff --git a/narwhals/plugins.py b/narwhals/plugins.py new file mode 100644 index 0000000000..c994509a02 --- /dev/null +++ b/narwhals/plugins.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import sys +from functools import cache +from typing import TYPE_CHECKING, Any, Protocol, cast + +from narwhals._compliant import CompliantNamespace +from narwhals._typing_compat import TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + from importlib.metadata import EntryPoints + + from typing_extensions import LiteralString, TypeAlias + + from narwhals._compliant.typing import ( + CompliantDataFrameAny, + CompliantFrameAny, + CompliantLazyFrameAny, + CompliantSeriesAny, + ) + from narwhals.utils import Version + + +__all__ = ["Plugin", "from_native"] + +CompliantAny: TypeAlias = ( + "CompliantDataFrameAny | CompliantLazyFrameAny | CompliantSeriesAny" +) +"""A statically-unknown, Compliant object originating from a plugin.""" + +FrameT = TypeVar( + "FrameT", + bound="CompliantFrameAny", + default="CompliantDataFrameAny | CompliantLazyFrameAny", +) +FromNativeR_co = TypeVar( + "FromNativeR_co", bound=CompliantAny, covariant=True, default=CompliantAny +) + + +@cache +def _discover_entrypoints() -> EntryPoints: + from importlib.metadata import entry_points as eps + + group = "narwhals.plugins" + if sys.version_info < (3, 10): + return cast("EntryPoints", eps().get(group, ())) + return eps(group=group) + + +class PluginNamespace(CompliantNamespace[FrameT, Any], Protocol[FrameT, FromNativeR_co]): + def from_native(self, data: Any, /) -> FromNativeR_co: ... + + +class Plugin(Protocol[FrameT, FromNativeR_co]): + NATIVE_PACKAGE: LiteralString + + def __narwhals_namespace__( + self, version: Version + ) -> PluginNamespace[FrameT, FromNativeR_co]: ... + def is_native(self, native_object: object, /) -> bool: ... + + +@cache +def _might_be(cls: type, type_: str) -> bool: + try: + return any(type_ in o.__module__.split(".") for o in cls.mro()) + except TypeError: + return False + + +def _is_native_plugin(native_object: Any, plugin: Plugin) -> bool: + pkg = plugin.NATIVE_PACKAGE + return ( + sys.modules.get(pkg) is not None + and _might_be(type(native_object), pkg) # type: ignore[arg-type] + and plugin.is_native(native_object) + ) + + +def _iter_from_native(native_object: Any, version: Version) -> Iterator[CompliantAny]: + for entry_point in _discover_entrypoints(): + plugin: Plugin = entry_point.load() + if _is_native_plugin(native_object, plugin): + compliant_namespace = plugin.__narwhals_namespace__(version=version) + yield compliant_namespace.from_native(native_object) + + +def from_native(native_object: Any, version: Version) -> CompliantAny | None: + """Attempt to convert `native_object` to a Compliant object, using any available plugin(s). + + Arguments: + native_object: Raw object from user. + version: Narwhals API version. + + Returns: + If the following conditions are met + - at least 1 plugin is installed + - at least 1 installed plugin supports `type(native_object)` + + Then for the **first matching plugin**, the result of the call below. + This *should* be an object accepted by a Narwhals Dataframe, Lazyframe, or Series: + + plugin: Plugin + plugin.__narwhals_namespace__(version).from_native(native_object) + + In all other cases, `None` is returned instead. + """ + return next(_iter_from_native(native_object, version), None) diff --git a/narwhals/plugins/__init__.py b/narwhals/plugins/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/narwhals/plugins/_utils.py b/narwhals/plugins/_utils.py deleted file mode 100644 index bf467d69b3..0000000000 --- a/narwhals/plugins/_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import sys -from functools import cache -from typing import TYPE_CHECKING, Any, Protocol, cast - -if TYPE_CHECKING: - from importlib.metadata import EntryPoints - - from typing_extensions import LiteralString - - from narwhals._compliant.typing import CompliantNamespaceAny - from narwhals.utils import Version - - -@cache -def discover_entrypoints() -> EntryPoints: - from importlib.metadata import entry_points as eps - - group = "narwhals.plugins" - if sys.version_info < (3, 10): - return cast("EntryPoints", eps().get(group, ())) - return eps(group=group) - - -class Plugin(Protocol): - NATIVE_PACKAGE: LiteralString - - def __narwhals_namespace__(self, version: Version) -> CompliantNamespaceAny: ... - def is_native(self, native_object: object, /) -> bool: ... - - -@cache -def _might_be(cls: type, type_: str) -> bool: - try: - return any(type_ in o.__module__.split(".") for o in cls.mro()) - except TypeError: - return False - - -def _is_native_plugin(native_object: Any, plugin: Plugin) -> bool: - pkg = plugin.NATIVE_PACKAGE - return ( - sys.modules.get(pkg) is not None - and _might_be(type(native_object), pkg) - and plugin.is_native(native_object) - ) diff --git a/narwhals/translate.py b/narwhals/translate.py index 3e5a3f5f66..33e84cd365 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -5,6 +5,7 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, overload +from narwhals import plugins from narwhals._constants import EPOCH, MS_PER_SECOND from narwhals._namespace import ( is_native_arrow, @@ -35,7 +36,6 @@ is_pyarrow_scalar, is_pyarrow_table, ) -from narwhals.plugins._utils import _is_native_plugin, discover_entrypoints if TYPE_CHECKING: from narwhals.dataframe import DataFrame, LazyFrame @@ -564,20 +564,17 @@ def _from_native_impl( # noqa: C901, PLR0911, PLR0912, PLR0915 raise TypeError(msg) return Version.V1.dataframe(InterchangeFrame(native_object), level="interchange") - for entry_point in discover_entrypoints(): - plugin = entry_point.load() - if _is_native_plugin(native_object, plugin): - compliant_namespace = plugin.__narwhals_namespace__(version=version) - compliant_object = compliant_namespace.from_native(native_object) - return _translate_if_compliant( - compliant_object, - pass_through=pass_through, - eager_only=eager_only, - eager_or_interchange_only=eager_or_interchange_only, - series_only=series_only, - allow_series=allow_series, - version=version, - ) + compliant_object = plugins.from_native(native_object, version) + if compliant_object is not None: + return _translate_if_compliant( + compliant_object, + pass_through=pass_through, + eager_only=eager_only, + eager_or_interchange_only=eager_or_interchange_only, + series_only=series_only, + allow_series=allow_series, + version=version, + ) if not pass_through: msg = f"Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: {type(native_object)}"