From e15518f047cf49b44f7f1bf594122f2cc7936893 Mon Sep 17 00:00:00 2001 From: Topher Cawlfield <4094385+tcawlfield@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:01:29 -0600 Subject: [PATCH] feat: Add ak.array_equal (#3215) * Preparatory refactor: create ak_almost_equal._impl * Adding ak.array_equal Includes a very minimal test case. Needs more. * Fixing bug in nplike, array_equal The NaN values were not compared correctly. Also adding several tests of ak.array_equal. Also firing a minor issue in ak.array_equal. * Fixing array_equal docstring * Update src/awkward/operations/ak_almost_equal.py Use is operator to compare classes Co-authored-by: Jim Pivarski * Remove defaults from args to ak_almost_equal._impl * Fixing more bugs * Possible fix for old numpy * style: pre-commit fixes --------- Co-authored-by: Jim Pivarski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/_nplikes/array_module.py | 5 +- src/awkward/operations/__init__.py | 1 + src/awkward/operations/ak_almost_equal.py | 46 +++++++++++- src/awkward/operations/ak_array_equal.py | 54 ++++++++++++++ tests/test_1105_ak_aray_equal.py | 90 +++++++++++++++++++++++ 5 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 src/awkward/operations/ak_array_equal.py create mode 100644 tests/test_1105_ak_aray_equal.py diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index b474b6571c..5217ace411 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -174,7 +174,10 @@ def array_equal( assert not isinstance(x1, PlaceholderArray) assert not isinstance(x2, PlaceholderArray) if equal_nan: - both_nan = self._module.logical_and(x1 == np.nan, x2 == np.nan) + # Only newer numpy.array_equal supports the equal_nan parameter. + both_nan = self._module.logical_and( + self._module.isnan(x1), self._module.isnan(x2) + ) both_equal = x1 == x2 return self._module.all(self._module.logical_or(both_equal, both_nan)) else: diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py index 90d4f99162..d0cee81508 100644 --- a/src/awkward/operations/__init__.py +++ b/src/awkward/operations/__init__.py @@ -12,6 +12,7 @@ from awkward.operations.ak_argmax import * from awkward.operations.ak_argmin import * from awkward.operations.ak_argsort import * +from awkward.operations.ak_array_equal import * from awkward.operations.ak_backend import * from awkward.operations.ak_broadcast_arrays import * from awkward.operations.ak_broadcast_fields import * diff --git a/src/awkward/operations/ak_almost_equal.py b/src/awkward/operations/ak_almost_equal.py index 6f54d3a79b..78461de65c 100644 --- a/src/awkward/operations/ak_almost_equal.py +++ b/src/awkward/operations/ak_almost_equal.py @@ -52,6 +52,32 @@ def almost_equal( # Dispatch yield left, right + return _impl( + left, + right, + rtol=rtol, + atol=atol, + dtype_exact=dtype_exact, + check_parameters=check_parameters, + check_regular=check_regular, + exact_eq=False, + same_content_types=False, + equal_nan=False, + ) + + +def _impl( + left, + right, + rtol: float, + atol: float, + dtype_exact: bool, + check_parameters: bool, + check_regular: bool, + exact_eq: bool, + same_content_types: bool, + equal_nan: bool, +): # Implementation left_behavior = behavior_of(left) right_behavior = behavior_of(right) @@ -82,6 +108,10 @@ def packed_list_content(layout): return layout.content[layout.offsets[0] : layout.offsets[-1]] def visitor(left, right) -> bool: + # Most firstly, check same_content_types before any transformations + if same_content_types and left.__class__ is not right.__class__: + return False + # First, erase indexed types! if left.is_indexed and not left.is_option: left = left.project() @@ -152,12 +182,26 @@ def visitor(left, right) -> bool: and backend.nplike.all(left.data == right.data) and left.shape == right.shape ) + elif exact_eq: + return ( + is_approx_dtype(left.dtype, right.dtype) + and backend.nplike.array_equal( + left.data, + right.data, + equal_nan=equal_nan, + ) + and left.shape == right.shape + ) else: return ( is_approx_dtype(left.dtype, right.dtype) and backend.nplike.all( backend.nplike.isclose( - left.data, right.data, rtol=rtol, atol=atol, equal_nan=False + left.data, + right.data, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, ) ) and left.shape == right.shape diff --git a/src/awkward/operations/ak_array_equal.py b/src/awkward/operations/ak_array_equal.py new file mode 100644 index 0000000000..2a7221baab --- /dev/null +++ b/src/awkward/operations/ak_array_equal.py @@ -0,0 +1,54 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import awkward as ak +from awkward._dispatch import high_level_function + +__all__ = ("array_equal",) + + +@high_level_function() +def array_equal( + a1, + a2, + equal_nan: bool = False, + dtype_exact: bool = True, + same_content_types: bool = True, + check_parameters: bool = True, + check_regular: bool = True, +): + """ + True if two arrays have the same shape and elements, False otherwise. + + Args: + a1: Array-like data (anything #ak.to_layout recognizes). + a2: Array-like data (anything #ak.to_layout recognizes). + equal_nan: bool (default=False) + Whether to count NaN values as equal to each other. + dtype_exact: bool (default=True) whether the dtypes must be exactly the same, or just the + same family. + same_content_types: bool (default=True) + Whether to require all content classes to match + check_parameters: bool (default=True) whether to compare parameters. + check_regular: bool (default=True) whether to consider ragged and regular dimensions as + unequal. + + TypeTracer arrays are not supported, as there is very little information to + be compared. + """ + # Dispatch + yield a1, a2 + + return ak.operations.ak_almost_equal._impl( + a1, + a2, + rtol=0.0, + atol=0.0, + dtype_exact=dtype_exact, + check_parameters=check_parameters, + check_regular=check_regular, + exact_eq=True, + same_content_types=same_content_types and check_regular, + equal_nan=equal_nan, + ) diff --git a/tests/test_1105_ak_aray_equal.py b/tests/test_1105_ak_aray_equal.py new file mode 100644 index 0000000000..519d0039c5 --- /dev/null +++ b/tests/test_1105_ak_aray_equal.py @@ -0,0 +1,90 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy as np + +import awkward as ak +from awkward.contents import NumpyArray +from awkward.index import Index64 + + +def test_array_equal_simple(): + assert ak.array_equal( + ak.Array([[1, 2], [], [3, 4, 5]]), + ak.Array([[1, 2], [], [3, 4, 5]]), + ) + + +def test_array_equal_mixed_dtype(): + assert not ak.array_equal( + ak.Array(np.array([1.5, 2.0, 3.25], dtype=np.float32)), + ak.Array(np.array([1.5, 2.0, 3.25], dtype=np.float64)), + ) + assert ak.array_equal( + ak.Array(np.array([1.5, 2.0, 3.25], dtype=np.float32)), + ak.Array(np.array([1.5, 2.0, 3.25], dtype=np.float64)), + dtype_exact=False, + ) + + +def test_array_equal_on_listoffsets(): + a1 = ak.contents.ListOffsetArray( + Index64(np.array([0, 3, 3, 5])), NumpyArray(np.arange(5) * 1.5) + ) + a2 = ak.contents.ListOffsetArray( + Index64(np.array([0, 3, 3, 5])), + NumpyArray(np.arange(10) * 1.5), # Longer array content than a1 + ) + assert ak.array_equal(a1, a2) + + +def test_array_equal_mixed_content_type(): + a1 = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8]]) + a1r = ak.to_regular(a1[:2]) + assert not ak.array_equal(a1[:2], a1r) + assert ak.array_equal(a1[:2], a1r, check_regular=False) + assert not ak.array_equal(a1, a1r, check_regular=False) + + assert ak.array_equal( + a1, a1.layout + ) # high-level automatically converted to layout in pre-check + + a2_np = ak.contents.NumpyArray(np.arange(3)) + a2_ia = ak.contents.IndexedArray( + Index64(np.array([0, 1, 2])), NumpyArray(np.arange(3)) + ) + assert ak.array_equal(a2_np, a2_ia, same_content_types=False) + + +def test_array_equal_opion_types(): + a1 = ak.Array([1, 2, None, 4]) + a2 = ak.concatenate([ak.Array([1, 2]), ak.Array([None, 4])]) + assert ak.array_equal(a1, a2) + + a3 = a1.layout.to_ByteMaskedArray(valid_when=True) + assert not ak.array_equal(a1, a3, same_content_types=True) + assert ak.array_equal(a1, a3, same_content_types=False) + assert not ak.array_equal( + a1, ak.Array([1, 2, 3, 4]), same_content_types=False, dtype_exact=False + ) + + +def test_array_equal_nan(): + a = ak.Array([1.0, 2.5, np.nan]) + nplike = a.layout.backend.nplike + assert not nplike.array_equal(a.layout.data, a.layout.data) + assert nplike.array_equal(a.layout.data, a.layout.data, equal_nan=True) + assert not ak.array_equal(a, a) + assert ak.array_equal(a, a, equal_nan=True) + + +def test_array_equal_with_params(): + a1 = NumpyArray( + np.array([1, 2, 3], dtype=np.uint32), parameters={"foo": {"bar": "baz"}} + ) + a2 = NumpyArray( + np.array([1, 2, 3], dtype=np.uint32), parameters={"foo": {"bar": "Not so fast"}} + ) + assert not ak.array_equal(a1, a2) + assert ak.array_equal(a1, a2, check_parameters=False)