-
Notifications
You must be signed in to change notification settings - Fork 171
feat: Add assert_series_equal
#2983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 44 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
ffb9421
WIP: assert series equal
FBruzzesi 5839b1c
WIP
FBruzzesi df92a53
merge main
FBruzzesi 6a0ae06
wait for is_close method
FBruzzesi bcab59c
wrong invert in check_exact
FBruzzesi 69fc160
folder structure as polars
FBruzzesi 7a05d17
merge main
FBruzzesi a07822f
WIP unit tests
FBruzzesi a1c7e9e
handle nested dtypes with recursion
FBruzzesi e7637ce
factor out nested checks
FBruzzesi 0c137ec
merge main
FBruzzesi fe30e5b
coverage
FBruzzesi 774419a
line length for docstring
FBruzzesi 20341e4
refactor into subfunctions
FBruzzesi f3f48b8
refactor tests
FBruzzesi b61a5a4
add docpage
FBruzzesi d21879d
merge main
FBruzzesi be0ae09
merge main
FBruzzesi f161929
skip tests with nested dtype for old pandas
FBruzzesi bae0a9a
skip pyarrow old versions
FBruzzesi 546110c
skip pyarrow old version for arrays
FBruzzesi b6c24d0
Merge branch 'main' into testing
FBruzzesi afb47a8
merge main
FBruzzesi fde4fc4
merge main
FBruzzesi 3726fe7
use zip_strict
FBruzzesi 94ec0dc
merge main
FBruzzesi d73e475
merge main, type ignore
FBruzzesi 6bba76b
merge main
FBruzzesi b28099f
less walrus ops, add example in docstrings
FBruzzesi 44620db
Merge branch 'main' into testing
dangotbanned 96a9ab1
test(suggestion): Only allow `nw.Series`
dangotbanned 288a961
refactor: factor-in `_maybe_apply_preprocessing`
dangotbanned 5d7040b
add coverage for type error
FBruzzesi 72b7cf0
merge main
FBruzzesi 4c9f8d9
categorical case
FBruzzesi af5f6d4
correct xfail cases
FBruzzesi ce7e41e
Merge branch 'main' into testing
FBruzzesi 14b56cd
merge main
FBruzzesi 9a4b605
skip old pyarrow
FBruzzesi 6433213
merge main
FBruzzesi 43646b4
update docstring to use narwhals.Series
FBruzzesi b81b781
fix(typing): Avoid type ignore on `CheckFn`
dangotbanned 1fa0643
Merge remote-tracking branch 'upstream/main' into testing
dangotbanned a0592e1
refactor(suggestion): Ensure consistent reasons?
dangotbanned 18b0652
merge main
FBruzzesi 153d51f
rm assert_frame_equal
FBruzzesi fc4390b
Merge branch 'main' into testing
FBruzzesi 67c4848
merge main
FBruzzesi 7c2c8c2
add NOTE on SeriesDetail
FBruzzesi 3da9399
Merge branch 'main' into testing
dangotbanned File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # `narwhals.testing` | ||
|
|
||
| ::: narwhals.testing | ||
| handler: python | ||
| options: | ||
| members: | ||
| - assert_series_equal |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from narwhals.testing.asserts.frame import assert_frame_equal | ||
| from narwhals.testing.asserts.series import assert_series_equal | ||
|
|
||
| __all__ = ("assert_frame_equal", "assert_series_equal") |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from narwhals.typing import IntoFrameT | ||
|
|
||
|
|
||
| def assert_frame_equal( | ||
| left: IntoFrameT, | ||
| right: IntoFrameT, | ||
| *, | ||
| check_row_order: bool = True, | ||
| check_column_order: bool = True, | ||
| check_dtypes: bool = True, | ||
| check_exact: bool = False, | ||
| rel_tol: float = 1e-05, | ||
| abs_tol: float = 1e-08, | ||
| categorical_as_str: bool = False, | ||
| ) -> None: | ||
| msg = "TODO" # pragma: no cover | ||
| raise NotImplementedError(msg) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,291 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import partial | ||
| from typing import TYPE_CHECKING, Any, Callable | ||
|
|
||
| from narwhals._utils import qualified_type_name, zip_strict | ||
| from narwhals.dependencies import is_narwhals_series | ||
| from narwhals.dtypes import Array, Boolean, Categorical, List, String, Struct | ||
| from narwhals.functions import new_series | ||
| from narwhals.testing.asserts.utils import raise_series_assertion_error | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing_extensions import TypeAlias | ||
|
|
||
| from narwhals.series import Series | ||
| from narwhals.typing import IntoSeriesT, SeriesT | ||
|
|
||
| CheckFn: TypeAlias = Callable[[Series[Any], Series[Any]], None] | ||
|
|
||
|
|
||
| def assert_series_equal( | ||
| left: Series[IntoSeriesT], | ||
| right: Series[IntoSeriesT], | ||
| *, | ||
| check_dtypes: bool = True, | ||
| check_names: bool = True, | ||
| check_order: bool = True, | ||
| check_exact: bool = False, | ||
| rel_tol: float = 1e-05, | ||
| abs_tol: float = 1e-08, | ||
| categorical_as_str: bool = False, | ||
| ) -> None: | ||
| """Assert that the left and right Series are equal. | ||
|
|
||
| Raises a detailed `AssertionError` if the Series differ. | ||
| This function is intended for use in unit tests. | ||
|
|
||
| Arguments: | ||
| left: The first Series to compare. | ||
| right: The second Series to compare. | ||
| check_dtypes: Requires data types to match. | ||
| check_names: Requires names to match. | ||
| check_order: Requires elements to appear in the same order. | ||
| check_exact: Requires float values to match exactly. If set to `False`, values are | ||
| considered equal when within tolerance of each other (see `rel_tol` and | ||
| `abs_tol`). Only affects columns with a Float data type. | ||
| rel_tol: Relative tolerance for inexact checking, given as a fraction of the | ||
| values in `right`. | ||
| abs_tol: Absolute tolerance for inexact checking. | ||
| categorical_as_str: Cast categorical columns to string before comparing. | ||
| Enabling this helps compare columns that do not share the same string cache. | ||
|
|
||
| Examples: | ||
| >>> import pandas as pd | ||
| >>> import narwhals as nw | ||
| >>> from narwhals.testing import assert_series_equal | ||
| >>> s1 = nw.from_native(pd.Series([1, 2, 3]), series_only=True) | ||
| >>> s2 = nw.from_native(pd.Series([1, 5, 3]), series_only=True) | ||
| >>> assert_series_equal(s1, s2) # doctest: +ELLIPSIS | ||
| Traceback (most recent call last): | ||
| ... | ||
| AssertionError: Series are different (exact value mismatch) | ||
| [left]: | ||
| ┌───────────────┐ | ||
| |Narwhals Series| | ||
| |---------------| | ||
| | 0 1 | | ||
| | 1 2 | | ||
| | 2 3 | | ||
| | dtype: int64 | | ||
| └───────────────┘ | ||
| [right]: | ||
| ┌───────────────┐ | ||
| |Narwhals Series| | ||
| |---------------| | ||
| | 0 1 | | ||
| | 1 5 | | ||
| | 2 3 | | ||
| | dtype: int64 | | ||
| └───────────────┘ | ||
| """ | ||
| __tracebackhide__ = True | ||
|
|
||
| if any(not is_narwhals_series(obj) for obj in (left, right)): | ||
| msg = ( | ||
| "Expected `narwhals.Series` instance, found:\n" | ||
| f"[left]: {qualified_type_name(type(left))}\n" | ||
| f"[right]: {qualified_type_name(type(right))}\n\n" | ||
| "Hint: Use `nw.from_native(obj, series_only=True) to convert each native " | ||
| "object into a `narwhals.Series` first." | ||
| ) | ||
| raise TypeError(msg) | ||
|
|
||
| _check_metadata(left, right, check_dtypes=check_dtypes, check_names=check_names) | ||
|
|
||
| if not check_order: | ||
| if left.dtype.is_nested(): | ||
| msg = "`check_order=False` is not supported (yet) with nested data type." | ||
| raise NotImplementedError(msg) | ||
| left, right = left.sort(), right.sort() | ||
|
|
||
| left_vals, right_vals = _check_null_values(left, right) | ||
|
|
||
| if check_exact or not left.dtype.is_float(): | ||
| _check_exact_values( | ||
| left_vals, | ||
| right_vals, | ||
| check_dtypes=check_dtypes, | ||
| check_exact=check_exact, | ||
| rel_tol=rel_tol, | ||
| abs_tol=abs_tol, | ||
| categorical_as_str=categorical_as_str, | ||
dangotbanned marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| else: | ||
| _check_approximate_values(left_vals, right_vals, rel_tol=rel_tol, abs_tol=abs_tol) | ||
|
|
||
|
|
||
| def _check_metadata( | ||
| left: SeriesT, right: SeriesT, *, check_dtypes: bool, check_names: bool | ||
| ) -> None: | ||
| """Check metadata information: implementation, length, dtype, and names.""" | ||
dangotbanned marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| left_impl, right_impl = left.implementation, right.implementation | ||
| if left_impl != right_impl: | ||
| raise_series_assertion_error("implementation mismatch", left_impl, right_impl) | ||
|
|
||
| left_len, right_len = len(left), len(right) | ||
| if left_len != right_len: | ||
| raise_series_assertion_error("length mismatch", left_len, right_len) | ||
|
|
||
| left_dtype, right_dtype = left.dtype, right.dtype | ||
| if check_dtypes and left_dtype != right_dtype: | ||
| raise_series_assertion_error("dtype mismatch", left_dtype, right_dtype) | ||
|
|
||
| left_name, right_name = left.name, right.name | ||
| if check_names and left_name != right_name: | ||
| raise_series_assertion_error("name mismatch", left_name, right_name) | ||
|
|
||
|
|
||
| def _check_null_values(left: SeriesT, right: SeriesT) -> tuple[SeriesT, SeriesT]: | ||
| """Check null value consistency and return non-null values.""" | ||
| left_null_count, right_null_count = left.null_count(), right.null_count() | ||
| left_null_mask, right_null_mask = left.is_null(), right.is_null() | ||
|
|
||
| if left_null_count != right_null_count or (left_null_mask != right_null_mask).any(): | ||
| raise_series_assertion_error( | ||
| "null value mismatch", left_null_count, right_null_count | ||
| ) | ||
|
|
||
| return left.filter(~left_null_mask), right.filter(~right_null_mask) | ||
|
|
||
|
|
||
| def _check_exact_values( | ||
| left: SeriesT, | ||
| right: SeriesT, | ||
| *, | ||
| check_dtypes: bool, | ||
| check_exact: bool, | ||
| rel_tol: float, | ||
| abs_tol: float, | ||
| categorical_as_str: bool, | ||
| ) -> None: | ||
| """Check exact value equality for various data types.""" | ||
| left_impl = left.implementation | ||
| left_dtype, right_dtype = left.dtype, right.dtype | ||
|
|
||
| is_not_equal_mask: Series[Any] | ||
| if left_dtype.is_numeric(): | ||
| # For _all_ numeric dtypes, we can use `is_close` with 0-tolerances to handle | ||
| # inf and nan values out of the box. | ||
| is_not_equal_mask = ~left.is_close(right, rel_tol=0, abs_tol=0, nans_equal=True) | ||
| elif ( | ||
| isinstance(left_dtype, (Array, List)) and isinstance(right_dtype, (Array, List)) | ||
| ) and left_dtype == right_dtype: | ||
| check_fn = partial( | ||
| assert_series_equal, | ||
| check_dtypes=check_dtypes, | ||
| check_names=False, | ||
| check_order=True, | ||
| check_exact=check_exact, | ||
| rel_tol=rel_tol, | ||
| abs_tol=abs_tol, | ||
| categorical_as_str=categorical_as_str, | ||
| ) | ||
| _check_list_like(left, right, left_dtype, right_dtype, check_fn=check_fn) | ||
| # If `_check_list_like` didn't raise, then every nested element is equal | ||
| is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl) | ||
| elif isinstance(left_dtype, Struct) and isinstance(right_dtype, Struct): | ||
| check_fn = partial( | ||
| assert_series_equal, | ||
| check_dtypes=True, | ||
| check_names=True, | ||
| check_order=True, | ||
| check_exact=check_exact, | ||
| rel_tol=rel_tol, | ||
| abs_tol=abs_tol, | ||
| categorical_as_str=categorical_as_str, | ||
| ) | ||
| _check_struct(left, right, left_dtype, right_dtype, check_fn=check_fn) | ||
| # If `_check_struct` didn't raise, then every nested element is equal | ||
| is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl) | ||
| elif isinstance(left_dtype, Categorical) and isinstance(right_dtype, Categorical): | ||
| # If `_check_categorical` didn't raise, then the categories sources/encodings are | ||
| # the same, and we can use equality | ||
| _not_equal = _check_categorical( | ||
| left, right, categorical_as_str=categorical_as_str | ||
| ) | ||
| is_not_equal_mask = new_series( | ||
| "", [_not_equal], dtype=Boolean(), backend=left_impl | ||
| ) | ||
| else: | ||
| is_not_equal_mask = left != right | ||
|
|
||
| if is_not_equal_mask.any(): | ||
| raise_series_assertion_error("exact value mismatch", left, right) | ||
|
|
||
|
|
||
| def _check_approximate_values( | ||
| left: SeriesT, right: SeriesT, *, rel_tol: float, abs_tol: float | ||
| ) -> None: | ||
| """Check approximate value equality with tolerance.""" | ||
| is_not_close_mask = ~left.is_close( | ||
| right, rel_tol=rel_tol, abs_tol=abs_tol, nans_equal=True | ||
| ) | ||
|
|
||
| if is_not_close_mask.any(): | ||
| raise_series_assertion_error( | ||
| "values not within tolerance", | ||
| left.filter(is_not_close_mask), | ||
| right.filter(is_not_close_mask), | ||
| ) | ||
|
|
||
|
|
||
| def _check_list_like( | ||
| left_vals: SeriesT, | ||
| right_vals: SeriesT, | ||
| left_dtype: List | Array, | ||
| right_dtype: List | Array, | ||
| check_fn: CheckFn, | ||
| ) -> None: | ||
| # Check row by row after transforming each array/list into a new series. | ||
| # Notice that order within the array/list must be the same, regardless of | ||
| # `check_order` value at the top level. | ||
| impl = left_vals.implementation | ||
| try: | ||
| for left_val, right_val in zip_strict(left_vals, right_vals): | ||
| check_fn( | ||
| new_series("", values=left_val, dtype=left_dtype.inner, backend=impl), | ||
| new_series("", values=right_val, dtype=right_dtype.inner, backend=impl), | ||
| ) | ||
| except AssertionError: | ||
| raise_series_assertion_error("nested value mismatch", left_vals, right_vals) | ||
|
|
||
|
|
||
| def _check_struct( | ||
| left_vals: SeriesT, | ||
| right_vals: SeriesT, | ||
| left_dtype: Struct, | ||
| right_dtype: Struct, | ||
| check_fn: CheckFn, | ||
| ) -> None: | ||
| # Check field by field as a separate column. | ||
| # Notice that for struct's polars raises if: | ||
| # * field names are different but values are equal | ||
| # * dtype differs, regardless of `check_dtypes=False` | ||
| # * order applies only at top level | ||
| try: | ||
| for left_field, right_field in zip_strict(left_dtype.fields, right_dtype.fields): | ||
| check_fn( | ||
| left_vals.struct.field(left_field.name), | ||
| right_vals.struct.field(right_field.name), | ||
| ) | ||
| except AssertionError: | ||
| raise_series_assertion_error("exact value mismatch", left_vals, right_vals) | ||
|
|
||
|
|
||
| def _check_categorical( | ||
| left_vals: SeriesT, right_vals: SeriesT, *, categorical_as_str: bool | ||
| ) -> bool: | ||
| """Try to compare if any element of categorical series' differ. | ||
|
|
||
| Inability to compare means that the encoding is different, and an exception is raised. | ||
| """ | ||
| if categorical_as_str: | ||
| left_vals, right_vals = left_vals.cast(String()), right_vals.cast(String()) | ||
|
|
||
| try: | ||
| return (left_vals != right_vals).any() | ||
| except Exception as exc: | ||
| msg = "Cannot compare categoricals coming from different sources." | ||
| # TODO(FBruzzesi): Improve error message? | ||
| raise AssertionError(msg) from exc | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any, Literal | ||
|
|
||
| from narwhals.dependencies import is_narwhals_series | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing_extensions import Never, TypeAlias | ||
|
|
||
| SeriesDetail: TypeAlias = Literal[ | ||
| "implementation mismatch", | ||
| "length mismatch", | ||
| "dtype mismatch", | ||
| "name mismatch", | ||
| "null value mismatch", | ||
| "exact value mismatch", | ||
| "values not within tolerance", | ||
| "nested value mismatch", | ||
| ] | ||
dangotbanned marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def raise_assertion_error( | ||
| objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None | ||
| ) -> Never: | ||
| """Raise a detailed assertion error.""" | ||
| __tracebackhide__ = True | ||
|
|
||
| trailing_left = "\n" if is_narwhals_series(left) else " " | ||
| trailing_right = "\n" if is_narwhals_series(right) else " " | ||
|
|
||
| msg = ( | ||
| f"{objects} are different ({detail})\n" | ||
| f"[left]:{trailing_left}{left}\n" | ||
| f"[right]:{trailing_right}{right}" | ||
| ) | ||
| raise AssertionError(msg) from cause | ||
|
|
||
|
|
||
| def raise_series_assertion_error( | ||
| detail: SeriesDetail, left: Any, right: Any, *, cause: Exception | None = None | ||
| ) -> Never: | ||
| raise_assertion_error("Series", detail, left, right, cause=cause) | ||
Empty file.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.