Skip to content
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
ffb9421
WIP: assert series equal
FBruzzesi Aug 5, 2025
5839b1c
WIP
FBruzzesi Aug 5, 2025
df92a53
merge main
FBruzzesi Aug 10, 2025
6a0ae06
wait for is_close method
FBruzzesi Aug 10, 2025
bcab59c
wrong invert in check_exact
FBruzzesi Aug 10, 2025
69fc160
folder structure as polars
FBruzzesi Aug 10, 2025
7a05d17
merge main
FBruzzesi Aug 13, 2025
a07822f
WIP unit tests
FBruzzesi Aug 13, 2025
a1c7e9e
handle nested dtypes with recursion
FBruzzesi Aug 13, 2025
e7637ce
factor out nested checks
FBruzzesi Aug 13, 2025
0c137ec
merge main
FBruzzesi Aug 13, 2025
fe30e5b
coverage
FBruzzesi Aug 13, 2025
774419a
line length for docstring
FBruzzesi Aug 13, 2025
20341e4
refactor into subfunctions
FBruzzesi Aug 13, 2025
f3f48b8
refactor tests
FBruzzesi Aug 13, 2025
b61a5a4
add docpage
FBruzzesi Aug 13, 2025
d21879d
merge main
FBruzzesi Aug 13, 2025
be0ae09
merge main
FBruzzesi Aug 15, 2025
f161929
skip tests with nested dtype for old pandas
FBruzzesi Aug 15, 2025
bae0a9a
skip pyarrow old versions
FBruzzesi Aug 15, 2025
546110c
skip pyarrow old version for arrays
FBruzzesi Aug 15, 2025
b6c24d0
Merge branch 'main' into testing
FBruzzesi Aug 17, 2025
afb47a8
merge main
FBruzzesi Aug 18, 2025
fde4fc4
merge main
FBruzzesi Aug 19, 2025
3726fe7
use zip_strict
FBruzzesi Aug 19, 2025
94ec0dc
merge main
FBruzzesi Aug 23, 2025
d73e475
merge main, type ignore
FBruzzesi Aug 23, 2025
6bba76b
merge main
FBruzzesi Aug 26, 2025
b28099f
less walrus ops, add example in docstrings
FBruzzesi Aug 26, 2025
44620db
Merge branch 'main' into testing
dangotbanned Aug 28, 2025
96a9ab1
test(suggestion): Only allow `nw.Series`
dangotbanned Aug 29, 2025
288a961
refactor: factor-in `_maybe_apply_preprocessing`
dangotbanned Aug 29, 2025
5d7040b
add coverage for type error
FBruzzesi Aug 30, 2025
72b7cf0
merge main
FBruzzesi Sep 4, 2025
4c9f8d9
categorical case
FBruzzesi Sep 4, 2025
af5f6d4
correct xfail cases
FBruzzesi Sep 4, 2025
ce7e41e
Merge branch 'main' into testing
FBruzzesi Sep 5, 2025
14b56cd
merge main
FBruzzesi Sep 5, 2025
9a4b605
skip old pyarrow
FBruzzesi Sep 5, 2025
6433213
merge main
FBruzzesi Sep 7, 2025
43646b4
update docstring to use narwhals.Series
FBruzzesi Sep 8, 2025
b81b781
fix(typing): Avoid type ignore on `CheckFn`
dangotbanned Sep 10, 2025
1fa0643
Merge remote-tracking branch 'upstream/main' into testing
dangotbanned Sep 10, 2025
a0592e1
refactor(suggestion): Ensure consistent reasons?
dangotbanned Sep 10, 2025
18b0652
merge main
FBruzzesi Sep 14, 2025
153d51f
rm assert_frame_equal
FBruzzesi Sep 14, 2025
fc4390b
Merge branch 'main' into testing
FBruzzesi Oct 6, 2025
67c4848
merge main
FBruzzesi Oct 7, 2025
7c2c8c2
add NOTE on SeriesDetail
FBruzzesi Oct 7, 2025
3da9399
Merge branch 'main' into testing
dangotbanned Oct 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ repos:
name: don't import from narwhals.dtypes (use `Version.dtypes` instead)
entry: |
(?x)
import\ narwhals.dtypes|
from\ narwhals\ import\ dtypes|
from\ narwhals.dtypes\ import\ [^D_]+|
import\ narwhals.stable.v1.dtypes|
from\ narwhals.stable\.v.\ import\ dtypes|
from\ narwhals.stable\.v.\.dtypes\ import
import\ narwhals(\.stable\.v\d)?\.dtypes|
from\ narwhals(\.stable\.v\d)?\ import\ dtypes|
^from\ narwhals(\.stable\.v\d)?\.dtypes\ import
\ (DType,\ )?
((Datetime|Duration|Enum)(,\ )?)+
((,\ )?DType)?
language: pygrep
files: ^narwhals/
exclude: |
Expand Down
7 changes: 7 additions & 0 deletions docs/api-reference/testing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# `narwhals.testing`

::: narwhals.testing
handler: python
options:
members:
- assert_series_equal
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ nav:
- api-reference/dtypes.md
- api-reference/exceptions.md
- api-reference/selectors.md
- api-reference/testing.md
- api-reference/typing.md
- api-reference/utils.md
- This: this.md
Expand Down
6 changes: 6 additions & 0 deletions narwhals/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from __future__ import annotations

from narwhals.testing.asserts.frame import assert_frame_equal
from narwhals.testing.asserts.series import assert_series_equal

__all__ = ("assert_frame_equal", "assert_series_equal")
Empty file.
22 changes: 22 additions & 0 deletions narwhals/testing/asserts/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from narwhals.typing import IntoFrameT


def assert_frame_equal(
left: IntoFrameT,
right: IntoFrameT,
*,
check_row_order: bool = True,
check_column_order: bool = True,
check_dtypes: bool = True,
check_exact: bool = False,
rel_tol: float = 1e-05,
abs_tol: float = 1e-08,
categorical_as_str: bool = False,
) -> None:
msg = "TODO" # pragma: no cover
raise NotImplementedError(msg)
291 changes: 291 additions & 0 deletions narwhals/testing/asserts/series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Callable

from narwhals._utils import qualified_type_name, zip_strict
from narwhals.dependencies import is_narwhals_series
from narwhals.dtypes import Array, Boolean, Categorical, List, String, Struct
from narwhals.functions import new_series
from narwhals.testing.asserts.utils import raise_series_assertion_error

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from narwhals.series import Series
from narwhals.typing import IntoSeriesT, SeriesT

CheckFn: TypeAlias = Callable[[Series[Any], Series[Any]], None]


def assert_series_equal(
left: Series[IntoSeriesT],
right: Series[IntoSeriesT],
*,
check_dtypes: bool = True,
check_names: bool = True,
check_order: bool = True,
check_exact: bool = False,
rel_tol: float = 1e-05,
abs_tol: float = 1e-08,
categorical_as_str: bool = False,
) -> None:
"""Assert that the left and right Series are equal.

Raises a detailed `AssertionError` if the Series differ.
This function is intended for use in unit tests.

Arguments:
left: The first Series to compare.
right: The second Series to compare.
check_dtypes: Requires data types to match.
check_names: Requires names to match.
check_order: Requires elements to appear in the same order.
check_exact: Requires float values to match exactly. If set to `False`, values are
considered equal when within tolerance of each other (see `rel_tol` and
`abs_tol`). Only affects columns with a Float data type.
rel_tol: Relative tolerance for inexact checking, given as a fraction of the
values in `right`.
abs_tol: Absolute tolerance for inexact checking.
categorical_as_str: Cast categorical columns to string before comparing.
Enabling this helps compare columns that do not share the same string cache.

Examples:
>>> import pandas as pd
>>> import narwhals as nw
>>> from narwhals.testing import assert_series_equal
>>> s1 = nw.from_native(pd.Series([1, 2, 3]), series_only=True)
>>> s2 = nw.from_native(pd.Series([1, 5, 3]), series_only=True)
>>> assert_series_equal(s1, s2) # doctest: +ELLIPSIS
Traceback (most recent call last):
...
AssertionError: Series are different (exact value mismatch)
[left]:
┌───────────────┐
|Narwhals Series|
|---------------|
| 0 1 |
| 1 2 |
| 2 3 |
| dtype: int64 |
└───────────────┘
[right]:
┌───────────────┐
|Narwhals Series|
|---------------|
| 0 1 |
| 1 5 |
| 2 3 |
| dtype: int64 |
└───────────────┘
"""
__tracebackhide__ = True

if any(not is_narwhals_series(obj) for obj in (left, right)):
msg = (
"Expected `narwhals.Series` instance, found:\n"
f"[left]: {qualified_type_name(type(left))}\n"
f"[right]: {qualified_type_name(type(right))}\n\n"
"Hint: Use `nw.from_native(obj, series_only=True) to convert each native "
"object into a `narwhals.Series` first."
)
raise TypeError(msg)

_check_metadata(left, right, check_dtypes=check_dtypes, check_names=check_names)

if not check_order:
if left.dtype.is_nested():
msg = "`check_order=False` is not supported (yet) with nested data type."
raise NotImplementedError(msg)
left, right = left.sort(), right.sort()

left_vals, right_vals = _check_null_values(left, right)

if check_exact or not left.dtype.is_float():
_check_exact_values(
left_vals,
right_vals,
check_dtypes=check_dtypes,
check_exact=check_exact,
rel_tol=rel_tol,
abs_tol=abs_tol,
categorical_as_str=categorical_as_str,
)
else:
_check_approximate_values(left_vals, right_vals, rel_tol=rel_tol, abs_tol=abs_tol)


def _check_metadata(
left: SeriesT, right: SeriesT, *, check_dtypes: bool, check_names: bool
) -> None:
"""Check metadata information: implementation, length, dtype, and names."""
left_impl, right_impl = left.implementation, right.implementation
if left_impl != right_impl:
raise_series_assertion_error("implementation mismatch", left_impl, right_impl)

left_len, right_len = len(left), len(right)
if left_len != right_len:
raise_series_assertion_error("length mismatch", left_len, right_len)

left_dtype, right_dtype = left.dtype, right.dtype
if check_dtypes and left_dtype != right_dtype:
raise_series_assertion_error("dtype mismatch", left_dtype, right_dtype)

left_name, right_name = left.name, right.name
if check_names and left_name != right_name:
raise_series_assertion_error("name mismatch", left_name, right_name)


def _check_null_values(left: SeriesT, right: SeriesT) -> tuple[SeriesT, SeriesT]:
"""Check null value consistency and return non-null values."""
left_null_count, right_null_count = left.null_count(), right.null_count()
left_null_mask, right_null_mask = left.is_null(), right.is_null()

if left_null_count != right_null_count or (left_null_mask != right_null_mask).any():
raise_series_assertion_error(
"null value mismatch", left_null_count, right_null_count
)

return left.filter(~left_null_mask), right.filter(~right_null_mask)


def _check_exact_values(
left: SeriesT,
right: SeriesT,
*,
check_dtypes: bool,
check_exact: bool,
rel_tol: float,
abs_tol: float,
categorical_as_str: bool,
) -> None:
"""Check exact value equality for various data types."""
left_impl = left.implementation
left_dtype, right_dtype = left.dtype, right.dtype

is_not_equal_mask: Series[Any]
if left_dtype.is_numeric():
# For _all_ numeric dtypes, we can use `is_close` with 0-tolerances to handle
# inf and nan values out of the box.
is_not_equal_mask = ~left.is_close(right, rel_tol=0, abs_tol=0, nans_equal=True)
elif (
isinstance(left_dtype, (Array, List)) and isinstance(right_dtype, (Array, List))
) and left_dtype == right_dtype:
check_fn = partial(
assert_series_equal,
check_dtypes=check_dtypes,
check_names=False,
check_order=True,
check_exact=check_exact,
rel_tol=rel_tol,
abs_tol=abs_tol,
categorical_as_str=categorical_as_str,
)
_check_list_like(left, right, left_dtype, right_dtype, check_fn=check_fn)
# If `_check_list_like` didn't raise, then every nested element is equal
is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl)
elif isinstance(left_dtype, Struct) and isinstance(right_dtype, Struct):
check_fn = partial(
assert_series_equal,
check_dtypes=True,
check_names=True,
check_order=True,
check_exact=check_exact,
rel_tol=rel_tol,
abs_tol=abs_tol,
categorical_as_str=categorical_as_str,
)
_check_struct(left, right, left_dtype, right_dtype, check_fn=check_fn)
# If `_check_struct` didn't raise, then every nested element is equal
is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl)
elif isinstance(left_dtype, Categorical) and isinstance(right_dtype, Categorical):
# If `_check_categorical` didn't raise, then the categories sources/encodings are
# the same, and we can use equality
_not_equal = _check_categorical(
left, right, categorical_as_str=categorical_as_str
)
is_not_equal_mask = new_series(
"", [_not_equal], dtype=Boolean(), backend=left_impl
)
else:
is_not_equal_mask = left != right

if is_not_equal_mask.any():
raise_series_assertion_error("exact value mismatch", left, right)


def _check_approximate_values(
left: SeriesT, right: SeriesT, *, rel_tol: float, abs_tol: float
) -> None:
"""Check approximate value equality with tolerance."""
is_not_close_mask = ~left.is_close(
right, rel_tol=rel_tol, abs_tol=abs_tol, nans_equal=True
)

if is_not_close_mask.any():
raise_series_assertion_error(
"values not within tolerance",
left.filter(is_not_close_mask),
right.filter(is_not_close_mask),
)


def _check_list_like(
left_vals: SeriesT,
right_vals: SeriesT,
left_dtype: List | Array,
right_dtype: List | Array,
check_fn: CheckFn,
) -> None:
# Check row by row after transforming each array/list into a new series.
# Notice that order within the array/list must be the same, regardless of
# `check_order` value at the top level.
impl = left_vals.implementation
try:
for left_val, right_val in zip_strict(left_vals, right_vals):
check_fn(
new_series("", values=left_val, dtype=left_dtype.inner, backend=impl),
new_series("", values=right_val, dtype=right_dtype.inner, backend=impl),
)
except AssertionError:
raise_series_assertion_error("nested value mismatch", left_vals, right_vals)


def _check_struct(
left_vals: SeriesT,
right_vals: SeriesT,
left_dtype: Struct,
right_dtype: Struct,
check_fn: CheckFn,
) -> None:
# Check field by field as a separate column.
# Notice that for struct's polars raises if:
# * field names are different but values are equal
# * dtype differs, regardless of `check_dtypes=False`
# * order applies only at top level
try:
for left_field, right_field in zip_strict(left_dtype.fields, right_dtype.fields):
check_fn(
left_vals.struct.field(left_field.name),
right_vals.struct.field(right_field.name),
)
except AssertionError:
raise_series_assertion_error("exact value mismatch", left_vals, right_vals)


def _check_categorical(
left_vals: SeriesT, right_vals: SeriesT, *, categorical_as_str: bool
) -> bool:
"""Try to compare if any element of categorical series' differ.

Inability to compare means that the encoding is different, and an exception is raised.
"""
if categorical_as_str:
left_vals, right_vals = left_vals.cast(String()), right_vals.cast(String())

try:
return (left_vals != right_vals).any()
except Exception as exc:
msg = "Cannot compare categoricals coming from different sources."
# TODO(FBruzzesi): Improve error message?
raise AssertionError(msg) from exc
42 changes: 42 additions & 0 deletions narwhals/testing/asserts/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from narwhals.dependencies import is_narwhals_series

if TYPE_CHECKING:
from typing_extensions import Never, TypeAlias

SeriesDetail: TypeAlias = Literal[
"implementation mismatch",
"length mismatch",
"dtype mismatch",
"name mismatch",
"null value mismatch",
"exact value mismatch",
"values not within tolerance",
"nested value mismatch",
]


def raise_assertion_error(
objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
"""Raise a detailed assertion error."""
__tracebackhide__ = True

trailing_left = "\n" if is_narwhals_series(left) else " "
trailing_right = "\n" if is_narwhals_series(right) else " "

msg = (
f"{objects} are different ({detail})\n"
f"[left]:{trailing_left}{left}\n"
f"[right]:{trailing_right}{right}"
)
raise AssertionError(msg) from cause


def raise_series_assertion_error(
detail: SeriesDetail, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
raise_assertion_error("Series", detail, left, right, cause=cause)
Empty file added tests/testing/__init__.py
Empty file.
Loading
Loading