Skip to content

Commit 9c866c6

Browse files
committed
Update function to round_and_hash_float_array and add test
1 parent 2498716 commit 9c866c6

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

tests/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def assert_valid_field_data(data: xr.DataArray, grid: XGrid):
140140
assert ax_actual == ax_expected, f"Expected axis {ax_expected} for dimension '{dim}', got {ax_actual}"
141141

142142

143-
def hash_float_array(arr):
143+
def round_and_hash_float_array(arr, decimals=6):
144+
arr = np.round(arr, decimals=decimals)
145+
144146
# Adapted from https://cs.stackexchange.com/a/37965
145147
h = 1
146148
for f in arr:

tests/v4/test_particleset_execute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def test_uxstommelgyre_pset_execute():
155155
dt=np.timedelta64(60, "s"),
156156
pyfunc=AdvectionEE,
157157
)
158-
assert utils.hash_float_array([p.lon for p in pset]) == 1165396086
159-
assert utils.hash_float_array([p.lat for p in pset]) == 1142124776
158+
assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086
159+
assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776
160160

161161

162162
@pytest.mark.xfail(reason="Output file not implemented yet")

tests/v4/test_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
from tests import utils
4+
5+
6+
def test_round_and_hash_float_array():
7+
decimals = 7
8+
arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
9+
h = utils.round_and_hash_float_array(arr, decimals=decimals)
10+
assert h == 1068792616613
11+
12+
delta = 10**-decimals
13+
arr_test = arr + 0.49 * delta
14+
h2 = utils.round_and_hash_float_array(arr_test, decimals=decimals)
15+
assert h2 == h
16+
17+
arr_test = arr + 0.51 * delta
18+
h3 = utils.round_and_hash_float_array(arr_test, decimals=decimals)
19+
assert h3 != h

0 commit comments

Comments
 (0)