diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45518534d0..ba81cc144f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -83,12 +83,12 @@ repos: name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | (?x) - import\ narwhals.dtypes| - from\ narwhals\ import\ dtypes| - from\ narwhals.dtypes\ import\ [^D_]+| - import\ narwhals.stable.v1.dtypes| - from\ narwhals.stable\.v.\ import\ dtypes| - from\ narwhals.stable\.v.\.dtypes\ import + import\ narwhals(\.stable\.v\d)?\.dtypes| + from\ narwhals(\.stable\.v\d)?\ import\ dtypes| + ^from\ narwhals(\.stable\.v\d)?\.dtypes\ import + \ (DType,\ )? + ((Datetime|Duration|Enum)(,\ )?)+ + ((,\ )?DType)? language: pygrep files: ^narwhals/ exclude: | diff --git a/docs/api-reference/testing.md b/docs/api-reference/testing.md new file mode 100644 index 0000000000..e73fb4d447 --- /dev/null +++ b/docs/api-reference/testing.md @@ -0,0 +1,7 @@ +# `narwhals.testing` + +::: narwhals.testing + handler: python + options: + members: + - assert_series_equal diff --git a/mkdocs.yml b/mkdocs.yml index 73c7888f21..c02f227f87 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,6 +69,7 @@ nav: - api-reference/dtypes.md - api-reference/exceptions.md - api-reference/selectors.md + - api-reference/testing.md - api-reference/typing.md - api-reference/utils.md - This: this.md diff --git a/narwhals/testing/__init__.py b/narwhals/testing/__init__.py new file mode 100644 index 0000000000..6bbab67b64 --- /dev/null +++ b/narwhals/testing/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from narwhals.testing.asserts.series import assert_series_equal + +__all__ = ("assert_series_equal",) diff --git a/narwhals/testing/asserts/__init__.py b/narwhals/testing/asserts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/narwhals/testing/asserts/series.py b/narwhals/testing/asserts/series.py new file mode 100644 index 0000000000..1408da0309 --- /dev/null +++ b/narwhals/testing/asserts/series.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any, Callable + +from narwhals._utils import qualified_type_name, zip_strict +from narwhals.dependencies import is_narwhals_series +from narwhals.dtypes import Array, Boolean, Categorical, List, String, Struct +from narwhals.functions import new_series +from narwhals.testing.asserts.utils import raise_series_assertion_error + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals.series import Series + from narwhals.typing import IntoSeriesT, SeriesT + + CheckFn: TypeAlias = Callable[[Series[Any], Series[Any]], None] + + +def assert_series_equal( + left: Series[IntoSeriesT], + right: Series[IntoSeriesT], + *, + check_dtypes: bool = True, + check_names: bool = True, + check_order: bool = True, + check_exact: bool = False, + rel_tol: float = 1e-05, + abs_tol: float = 1e-08, + categorical_as_str: bool = False, +) -> None: + """Assert that the left and right Series are equal. + + Raises a detailed `AssertionError` if the Series differ. + This function is intended for use in unit tests. + + Arguments: + left: The first Series to compare. + right: The second Series to compare. + check_dtypes: Requires data types to match. + check_names: Requires names to match. + check_order: Requires elements to appear in the same order. + check_exact: Requires float values to match exactly. If set to `False`, values are + considered equal when within tolerance of each other (see `rel_tol` and + `abs_tol`). Only affects columns with a Float data type. + rel_tol: Relative tolerance for inexact checking, given as a fraction of the + values in `right`. + abs_tol: Absolute tolerance for inexact checking. + categorical_as_str: Cast categorical columns to string before comparing. + Enabling this helps compare columns that do not share the same string cache. + + Examples: + >>> import pandas as pd + >>> import narwhals as nw + >>> from narwhals.testing import assert_series_equal + >>> s1 = nw.from_native(pd.Series([1, 2, 3]), series_only=True) + >>> s2 = nw.from_native(pd.Series([1, 5, 3]), series_only=True) + >>> assert_series_equal(s1, s2) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + AssertionError: Series are different (exact value mismatch) + [left]: + ┌───────────────┐ + |Narwhals Series| + |---------------| + | 0 1 | + | 1 2 | + | 2 3 | + | dtype: int64 | + └───────────────┘ + [right]: + ┌───────────────┐ + |Narwhals Series| + |---------------| + | 0 1 | + | 1 5 | + | 2 3 | + | dtype: int64 | + └───────────────┘ + """ + __tracebackhide__ = True + + if any(not is_narwhals_series(obj) for obj in (left, right)): + msg = ( + "Expected `narwhals.Series` instance, found:\n" + f"[left]: {qualified_type_name(type(left))}\n" + f"[right]: {qualified_type_name(type(right))}\n\n" + "Hint: Use `nw.from_native(obj, series_only=True) to convert each native " + "object into a `narwhals.Series` first." + ) + raise TypeError(msg) + + _check_metadata(left, right, check_dtypes=check_dtypes, check_names=check_names) + + if not check_order: + if left.dtype.is_nested(): + msg = "`check_order=False` is not supported (yet) with nested data type." + raise NotImplementedError(msg) + left, right = left.sort(), right.sort() + + left_vals, right_vals = _check_null_values(left, right) + + if check_exact or not left.dtype.is_float(): + _check_exact_values( + left_vals, + right_vals, + check_dtypes=check_dtypes, + check_exact=check_exact, + rel_tol=rel_tol, + abs_tol=abs_tol, + categorical_as_str=categorical_as_str, + ) + else: + _check_approximate_values(left_vals, right_vals, rel_tol=rel_tol, abs_tol=abs_tol) + + +def _check_metadata( + left: SeriesT, right: SeriesT, *, check_dtypes: bool, check_names: bool +) -> None: + """Check metadata information: implementation, length, dtype, and names.""" + left_impl, right_impl = left.implementation, right.implementation + if left_impl != right_impl: + raise_series_assertion_error("implementation mismatch", left_impl, right_impl) + + left_len, right_len = len(left), len(right) + if left_len != right_len: + raise_series_assertion_error("length mismatch", left_len, right_len) + + left_dtype, right_dtype = left.dtype, right.dtype + if check_dtypes and left_dtype != right_dtype: + raise_series_assertion_error("dtype mismatch", left_dtype, right_dtype) + + left_name, right_name = left.name, right.name + if check_names and left_name != right_name: + raise_series_assertion_error("name mismatch", left_name, right_name) + + +def _check_null_values(left: SeriesT, right: SeriesT) -> tuple[SeriesT, SeriesT]: + """Check null value consistency and return non-null values.""" + left_null_count, right_null_count = left.null_count(), right.null_count() + left_null_mask, right_null_mask = left.is_null(), right.is_null() + + if left_null_count != right_null_count or (left_null_mask != right_null_mask).any(): + raise_series_assertion_error( + "null value mismatch", left_null_count, right_null_count + ) + + return left.filter(~left_null_mask), right.filter(~right_null_mask) + + +def _check_exact_values( + left: SeriesT, + right: SeriesT, + *, + check_dtypes: bool, + check_exact: bool, + rel_tol: float, + abs_tol: float, + categorical_as_str: bool, +) -> None: + """Check exact value equality for various data types.""" + left_impl = left.implementation + left_dtype, right_dtype = left.dtype, right.dtype + + is_not_equal_mask: Series[Any] + if left_dtype.is_numeric(): + # For _all_ numeric dtypes, we can use `is_close` with 0-tolerances to handle + # inf and nan values out of the box. + is_not_equal_mask = ~left.is_close(right, rel_tol=0, abs_tol=0, nans_equal=True) + elif ( + isinstance(left_dtype, (Array, List)) and isinstance(right_dtype, (Array, List)) + ) and left_dtype == right_dtype: + check_fn = partial( + assert_series_equal, + check_dtypes=check_dtypes, + check_names=False, + check_order=True, + check_exact=check_exact, + rel_tol=rel_tol, + abs_tol=abs_tol, + categorical_as_str=categorical_as_str, + ) + _check_list_like(left, right, left_dtype, right_dtype, check_fn=check_fn) + # If `_check_list_like` didn't raise, then every nested element is equal + is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl) + elif isinstance(left_dtype, Struct) and isinstance(right_dtype, Struct): + check_fn = partial( + assert_series_equal, + check_dtypes=True, + check_names=True, + check_order=True, + check_exact=check_exact, + rel_tol=rel_tol, + abs_tol=abs_tol, + categorical_as_str=categorical_as_str, + ) + _check_struct(left, right, left_dtype, right_dtype, check_fn=check_fn) + # If `_check_struct` didn't raise, then every nested element is equal + is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl) + elif isinstance(left_dtype, Categorical) and isinstance(right_dtype, Categorical): + # If `_check_categorical` didn't raise, then the categories sources/encodings are + # the same, and we can use equality + _not_equal = _check_categorical( + left, right, categorical_as_str=categorical_as_str + ) + is_not_equal_mask = new_series( + "", [_not_equal], dtype=Boolean(), backend=left_impl + ) + else: + is_not_equal_mask = left != right + + if is_not_equal_mask.any(): + raise_series_assertion_error("exact value mismatch", left, right) + + +def _check_approximate_values( + left: SeriesT, right: SeriesT, *, rel_tol: float, abs_tol: float +) -> None: + """Check approximate value equality with tolerance.""" + is_not_close_mask = ~left.is_close( + right, rel_tol=rel_tol, abs_tol=abs_tol, nans_equal=True + ) + + if is_not_close_mask.any(): + raise_series_assertion_error( + "values not within tolerance", + left.filter(is_not_close_mask), + right.filter(is_not_close_mask), + ) + + +def _check_list_like( + left_vals: SeriesT, + right_vals: SeriesT, + left_dtype: List | Array, + right_dtype: List | Array, + check_fn: CheckFn, +) -> None: + # Check row by row after transforming each array/list into a new series. + # Notice that order within the array/list must be the same, regardless of + # `check_order` value at the top level. + impl = left_vals.implementation + try: + for left_val, right_val in zip_strict(left_vals, right_vals): + check_fn( + new_series("", values=left_val, dtype=left_dtype.inner, backend=impl), + new_series("", values=right_val, dtype=right_dtype.inner, backend=impl), + ) + except AssertionError: + raise_series_assertion_error("nested value mismatch", left_vals, right_vals) + + +def _check_struct( + left_vals: SeriesT, + right_vals: SeriesT, + left_dtype: Struct, + right_dtype: Struct, + check_fn: CheckFn, +) -> None: + # Check field by field as a separate column. + # Notice that for struct's polars raises if: + # * field names are different but values are equal + # * dtype differs, regardless of `check_dtypes=False` + # * order applies only at top level + try: + for left_field, right_field in zip_strict(left_dtype.fields, right_dtype.fields): + check_fn( + left_vals.struct.field(left_field.name), + right_vals.struct.field(right_field.name), + ) + except AssertionError: + raise_series_assertion_error("exact value mismatch", left_vals, right_vals) + + +def _check_categorical( + left_vals: SeriesT, right_vals: SeriesT, *, categorical_as_str: bool +) -> bool: + """Try to compare if any element of categorical series' differ. + + Inability to compare means that the encoding is different, and an exception is raised. + """ + if categorical_as_str: + left_vals, right_vals = left_vals.cast(String()), right_vals.cast(String()) + + try: + return (left_vals != right_vals).any() + except Exception as exc: + msg = "Cannot compare categoricals coming from different sources." + # TODO(FBruzzesi): Improve error message? + raise AssertionError(msg) from exc diff --git a/narwhals/testing/asserts/utils.py b/narwhals/testing/asserts/utils.py new file mode 100644 index 0000000000..13720c2a35 --- /dev/null +++ b/narwhals/testing/asserts/utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +from narwhals.dependencies import is_narwhals_series + +if TYPE_CHECKING: + from typing_extensions import Never, TypeAlias + +# NOTE: This alias is created to facilitate autocomplete. Feel free to extend it as +# you please when adding a new feature. +# See: https://github.com/narwhals-dev/narwhals/pull/2983#discussion_r2337548736 +SeriesDetail: TypeAlias = Literal[ + "implementation mismatch", + "length mismatch", + "dtype mismatch", + "name mismatch", + "null value mismatch", + "exact value mismatch", + "values not within tolerance", + "nested value mismatch", +] + + +def raise_assertion_error( + objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None +) -> Never: + """Raise a detailed assertion error.""" + __tracebackhide__ = True + + trailing_left = "\n" if is_narwhals_series(left) else " " + trailing_right = "\n" if is_narwhals_series(right) else " " + + msg = ( + f"{objects} are different ({detail})\n" + f"[left]:{trailing_left}{left}\n" + f"[right]:{trailing_right}{right}" + ) + raise AssertionError(msg) from cause + + +def raise_series_assertion_error( + detail: SeriesDetail, left: Any, right: Any, *, cause: Exception | None = None +) -> Never: + raise_assertion_error("Series", detail, left, right, cause=cause) diff --git a/tests/testing/__init__.py b/tests/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/testing/assert_series_equal_test.py b/tests/testing/assert_series_equal_test.py new file mode 100644 index 0000000000..475792eb42 --- /dev/null +++ b/tests/testing/assert_series_equal_test.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +import re +from contextlib import AbstractContextManager, nullcontext as does_not_raise +from typing import TYPE_CHECKING, Any, Callable + +import pytest + +import narwhals as nw +from narwhals.testing import assert_series_equal +from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals.typing import IntoSchema, IntoSeriesT + from tests.conftest import Data + from tests.utils import ConstructorEager + + SetupFn: TypeAlias = Callable[[nw.Series[Any]], tuple[nw.Series[Any], nw.Series[Any]]] + + +def _assertion_error(detail: str) -> pytest.RaisesExc: + return pytest.raises( + AssertionError, match=re.escape(f"Series are different ({detail})") + ) + + +def series_from_native(native: IntoSeriesT) -> nw.Series[IntoSeriesT]: + return nw.from_native(native, series_only=True) + + +def test_self_equal( + constructor_eager: ConstructorEager, data: Data, schema: IntoSchema +) -> None: + """Test that a series is equal to itself, including nested dtypes with nulls.""" + if "pandas" in str(constructor_eager) and PANDAS_VERSION < (2, 2): # pragma: no cover + reason = "Pandas too old for nested dtypes" + pytest.skip(reason=reason) + + if "pyarrow_table" in str(constructor_eager) and PYARROW_VERSION < ( + 15, + 0, + ): # pragma: no cover + reason = ( + "pyarrow.lib.ArrowNotImplementedError: Unsupported cast from string to " + "dictionary using function cast_dictionary" + ) + pytest.skip(reason=reason) + + if "pyarrow_table" in str(constructor_eager): + # Replace Enum with Categorical, since Pyarrow does not support Enum + schema = {**schema, "enum": nw.Categorical()} + + df = nw.from_native(constructor_eager(data), eager_only=True) + for name, dtype in schema.items(): + assert_series_equal(df[name].cast(dtype), df[name].cast(dtype)) + + +def test_implementation_mismatch() -> None: + """Test that different implementations raise an error.""" + pytest.importorskip("pandas") + pytest.importorskip("pyarrow") + + import pandas as pd + import pyarrow as pa + + with _assertion_error("implementation mismatch"): + assert_series_equal( + series_from_native(pd.Series([1])), + series_from_native(pa.chunked_array([[2]])), # type: ignore[misc] # pyright: ignore[reportArgumentType] + ) + + +@pytest.mark.parametrize( + ("setup_fn", "error_msg"), + [ + (lambda s: (s, s.head(2)), "length mismatch"), + (lambda s: (s.cast(nw.UInt32()), s.cast(nw.Int64())), "dtype mismatch"), + (lambda s: (s.rename("foo"), s.rename("bar")), "name mismatch"), + ], +) +def test_metadata_checks( + constructor_eager: ConstructorEager, setup_fn: SetupFn, error_msg: str +) -> None: + """Test metadata validation (length, dtype, name).""" + series = nw.from_native(constructor_eager({"a": [1, 2, 3]}), eager_only=True)["a"] + left, right = setup_fn(series) + + with _assertion_error(error_msg): + assert_series_equal(left, right) + + +@pytest.mark.parametrize( + ("setup_fn", "error_msg", "check_dtypes", "check_names"), + [ + (lambda s: (s, s.cast(nw.UInt32())), "dtype mismatch", True, False), + (lambda s: (s, s.cast(nw.UInt32()).rename("baz")), "dtype mismatch", True, True), + (lambda s: (s, s.rename("baz")), "name mismatch", False, True), + ], +) +def test_metadata_checks_with_flags( + constructor_eager: ConstructorEager, + setup_fn: SetupFn, + error_msg: str, + *, + check_dtypes: bool, + check_names: bool, +) -> None: + """Test the effect of check_dtypes and check_names flags.""" + series = nw.from_native(constructor_eager({"a": [1, 2, 3]}), eager_only=True)["a"] + left, right = setup_fn(series) + + with _assertion_error(error_msg): + assert_series_equal( + left, right, check_dtypes=check_dtypes, check_names=check_names + ) + + assert_series_equal(left, right, check_dtypes=False, check_names=False) + + +@pytest.mark.parametrize( + ("dtype", "check_order", "context"), + [ + (nw.List(nw.Int32()), False, pytest.raises(NotImplementedError)), + (nw.List(nw.Int32()), True, does_not_raise()), + (nw.Int32(), False, does_not_raise()), + (nw.Int32(), True, does_not_raise()), + ], +) +def test_check_order( + constructor_eager: ConstructorEager, + dtype: nw.dtypes.DType, + *, + check_order: bool, + context: AbstractContextManager[Any], +) -> None: + """Test check_order behavior with nested and simple data.""" + if ( + "pandas" in str(constructor_eager) + and PANDAS_VERSION < (2, 2) + and dtype.is_nested() + ): # pragma: no cover + reason = "Pandas too old for nested dtypes" + pytest.skip(reason=reason) + + data: list[Any] = [[1, 2, 3]] if dtype.is_nested() else [1, 2, 3] + frame = nw.from_native(constructor_eager({"a": data}), eager_only=True) + left = right = frame["a"].cast(dtype) + + with context: + assert_series_equal(left, right, check_order=check_order, check_names=False) + + +@pytest.mark.parametrize( + "null_data", + [ + {"left": ["x", "y", None], "right": ["x", None, "y"]}, # Different null position + {"left": ["x", None, None], "right": [None, "x", "y"]}, # Different null counts + ], +) +def test_null_mismatch(constructor_eager: ConstructorEager, null_data: Data) -> None: + """Test null value mismatch detection.""" + frame = nw.from_native(constructor_eager(null_data), eager_only=True) + left, right = frame["left"], frame["right"] + with _assertion_error("null value mismatch"): + assert_series_equal(left, right, check_names=False) + + +@pytest.mark.parametrize( + ("check_exact", "abs_tol", "rel_tol", "context"), + [ + (True, 1e-3, 1e-3, _assertion_error("exact value mismatch")), + (False, 1e-3, 1e-3, _assertion_error("values not within tolerance")), + (False, 2e-1, 2e-1, does_not_raise()), + ], +) +def test_numeric( + constructor_eager: ConstructorEager, + *, + check_exact: bool, + abs_tol: float, + rel_tol: float, + context: AbstractContextManager[Any], +) -> None: + data = { + "left": [1.0, float("nan"), float("inf"), None, 1.1], + "right": [1.01, float("nan"), float("inf"), None, 1.11], + } + + frame = nw.from_native(constructor_eager(data), eager_only=True) + left, right = frame["left"], frame["right"] + with context: + assert_series_equal( + left, + right, + check_names=False, + check_exact=check_exact, + abs_tol=abs_tol, + rel_tol=rel_tol, + ) + + +@pytest.mark.parametrize( + ("l_vals", "r_vals", "check_exact", "context", "dtype"), + [ + ( + [["foo", "bar"]], + [["foo", None]], + True, + _assertion_error("nested value mismatch"), + nw.List(nw.String()), + ), + ( + [["foo", "bar"]], + [["foo", None]], + True, + _assertion_error("nested value mismatch"), + nw.Array(nw.String(), 2), + ), + ( + [[0.0, 0.1]], + [[0.1, 0.1]], + True, + _assertion_error("nested value mismatch"), + nw.List(nw.Float32()), + ), + ( + [[0.0, 0.1]], + [[0.1, 0.1]], + True, + _assertion_error("nested value mismatch"), + nw.Array(nw.Float32(), 2), + ), + ([[0.0, 1e-10]], [[1e-10, 0.0]], False, does_not_raise(), nw.List(nw.Float64())), + ( + [[0.0, 1e-10]], + [[1e-10, 0.0]], + False, + does_not_raise(), + nw.Array(nw.Float64(), 2), + ), + ], +) +def test_list_like( + constructor_eager: ConstructorEager, + l_vals: list[list[Any]], + r_vals: list[list[Any]], + *, + check_exact: bool, + context: AbstractContextManager[Any], + dtype: nw.dtypes.DType, +) -> None: + if "pandas" in str(constructor_eager) and PANDAS_VERSION < (2, 2): # pragma: no cover + reason = "Pandas too old for nested dtypes" + pytest.skip(reason=reason) + + if ( + "pyarrow_table" in str(constructor_eager) + and PYARROW_VERSION < (14, 0) + and dtype == nw.Array + ): # pragma: no cover + reason = ( + "pyarrow.lib.ArrowNotImplementedError: Unsupported cast from " + "list to fixed_size_list using function cast_fixed_size_list" + ) + pytest.skip(reason=reason) + + data = {"left": l_vals, "right": r_vals} + frame = nw.from_native(constructor_eager(data), eager_only=True) + left, right = frame["left"].cast(dtype), frame["right"].cast(dtype) + with context: + assert_series_equal(left, right, check_names=False, check_exact=check_exact) + + +@pytest.mark.parametrize( + ("l_vals", "r_vals", "check_exact", "context"), + [ + ( + [{"a": 0.0, "b": ["orca"]}, None], + [{"a": 1e-10, "b": ["orca"]}, None], + True, + _assertion_error("exact value mismatch"), + ), + ( + [{"a": 0.0, "b": ["beluga"]}, None], + [{"a": 0.0, "b": ["orca"]}, None], + False, + _assertion_error("exact value mismatch"), + ), + ( + [{"a": 0.0, "b": ["orca"]}, None], + [{"a": 1e-10, "b": ["orca"]}, None], + False, + does_not_raise(), + ), + ], +) +def test_struct( + constructor_eager: ConstructorEager, + l_vals: list[dict[str, Any]], + r_vals: list[dict[str, Any]], + *, + check_exact: bool, + context: AbstractContextManager[Any], +) -> None: + if "pandas" in str(constructor_eager) and PANDAS_VERSION < (2, 2): # pragma: no cover + reason = "Pandas too old for nested dtypes" + pytest.skip(reason=reason) + + dtype = nw.Struct({"a": nw.Float32(), "b": nw.List(nw.String())}) + data = {"left": l_vals, "right": r_vals} + frame = nw.from_native(constructor_eager(data), eager_only=True) + left, right = frame["left"].cast(dtype), frame["right"].cast(dtype) + with context: + assert_series_equal(left, right, check_names=False, check_exact=check_exact) + + +def test_non_nw_series() -> None: + pytest.importorskip("pandas") + + import pandas as pd + + with pytest.raises( + TypeError, match=re.escape("Expected `narwhals.Series` instance, found") + ): + assert_series_equal( + left=pd.Series([1]), # type: ignore[arg-type] + right=pd.Series([2]), # type: ignore[arg-type] + ) + + +@pytest.mark.parametrize( + ("categorical_as_str", "context"), + [ + (True, does_not_raise()), + ( + False, + pytest.raises( + AssertionError, + match="Cannot compare categoricals coming from different sources", + ), + ), + ], +) +def test_categorical_as_str( + request: pytest.FixtureRequest, + constructor_eager: ConstructorEager, + *, + categorical_as_str: bool, + context: AbstractContextManager[Any], +) -> None: + if ( + "polars" in str(constructor_eager) + and POLARS_VERSION >= (1, 32) + and not categorical_as_str + ): + # https://github.com/pola-rs/polars/pull/23016 removed StringCache, it still + # exists but it does nothing in python. + request.applymarker(pytest.mark.xfail) + + if "pyarrow_table" in str(constructor_eager) and not categorical_as_str: + # pyarrow dictionary dtype compares values, not the encoding. + request.applymarker(pytest.mark.xfail) + + if "pyarrow_table" in str(constructor_eager) and PYARROW_VERSION < ( + 15, + 0, + ): # pragma: no cover + reason = ( + "pyarrow.lib.ArrowNotImplementedError: Unsupported cast from string to " + "dictionary using function cast_dictionary" + ) + pytest.skip(reason=reason) + + data = { + "left": ["beluga", "dolphin", "narwhal", "orca"], + "right": ["unicorn", "orca", "narwhal", "orca"], + } + frame = nw.from_native(constructor_eager(data), eager_only=True) + left = frame["left"].cast(nw.Categorical())[2:] + right = frame["right"].cast(nw.Categorical())[2:] + + with context: + assert_series_equal( + left, right, check_names=False, categorical_as_str=categorical_as_str + ) diff --git a/tests/testing/conftest.py b/tests/testing/conftest.py new file mode 100644 index 0000000000..9ac4419f72 --- /dev/null +++ b/tests/testing/conftest.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw + +if TYPE_CHECKING: + from narwhals.typing import IntoSchema + from tests.conftest import Data + + +@pytest.fixture +def schema() -> IntoSchema: + return { + "int": nw.Int32(), + "float": nw.Float32(), + "str": nw.String(), + "categorical": nw.Categorical(), + "enum": nw.Enum(["beluga", "narwhal", "orca"]), + "bool": nw.Boolean(), + "datetime": nw.Datetime(), + "date": nw.Date(), + "time": nw.Time(), + "duration": nw.Duration(), + "binary": nw.Binary(), + "list": nw.List(nw.Float32()), + "array": nw.Array(nw.Int32(), shape=2), + "struct": nw.Struct({"a": nw.Int64(), "b": nw.List(nw.String())}), + } + + +@pytest.fixture +def data() -> Data: + return { + "int": [1, 2, 3, 4], + "float": [1.0, float("nan"), float("inf"), None], + "str": ["beluga", "narwhal", "orca", None], + "categorical": ["beluga", "narwhal", "beluga", None], + "enum": ["beluga", "narwhal", "orca", "narwhal"], + "bool": [True, False, True, None], + "datetime": [ + datetime(2025, 1, 1, 12), + datetime(2025, 1, 2, 12), + datetime(2025, 1, 3, 12), + None, + ], + "date": [date(2025, 1, 1), date(2025, 1, 2), date(2025, 1, 3), None], + "time": [time(9, 0), time(9, 1, 10), time(9, 2), None], + "duration": [ + timedelta(seconds=1), + timedelta(seconds=2), + timedelta(seconds=3), + None, + ], + "binary": [b"foo", b"bar", b"baz", None], + "list": [[1.0, float("nan")], [], [None], None], + "array": [[1, 2], [3, 4], [5, 6], None], + "struct": [ + {"a": 1, "b": ["narwhal", "beluga"]}, + {"a": 2, "b": ["orca"]}, + {"a": 3, "b": [None]}, + None, + ], + }