The test utility `np.testing.assert_equal` treats NaNs in the same positions as equal. However, this is not the case for some ml_dtypes arrays:
```python
import ml_dtypes
import numpy as np

# This will succeed
fp32_array = np.array(np.nan, dtype=np.float32)
np.testing.assert_equal(fp32_array, fp32_array)

# This will fail
array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
np.testing.assert_equal(array, array)
```
The second assertion fails with:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "numpy/testing/_private/utils.py", line 371, in assert_equal
    return assert_array_equal(actual, desired, err_msg, verbose,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "numpy/testing/_private/utils.py", line 1051, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "numpy/testing/_private/utils.py", line 916, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal
Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: nan
Max relative difference among violations: nan
ACTUAL: array(nan, dtype=bfloat16)
DESIRED: array(nan, dtype=bfloat16)
```
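Until this is fixed, one possible workaround is to upcast both operands to `float32` before comparing. This works because every `bfloat16` value is exactly representable in `float32`, so the cast changes no values and NaNs stay NaN. `assert_equal` then applies its usual NaN handling to a built-in dtype. The helper name `assert_equal_upcast` is hypothetical, and the sketch assumes the compared ml_dtypes type is a floating-point format that fits losslessly in `float32`. It is not part of NumPy or ml_dtypes.

```python
import ml_dtypes
import numpy as np


def assert_equal_upcast(actual, desired, err_msg=""):
    """Hypothetical workaround helper: compare low-precision floats via float32.

    Assumes every value of the input dtype is exactly representable in float32
    (true for bfloat16), so upcasting is lossless and NaNs remain NaN.
    np.testing.assert_equal then treats NaNs in matching positions as equal.
    """
    actual32 = np.asarray(actual).astype(np.float32)
    desired32 = np.asarray(desired).astype(np.float32)
    np.testing.assert_equal(actual32, desired32, err_msg=err_msg)


array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
assert_equal_upcast(array, array)  # passes: the NaN positions match
```

Because the cast is exact, real mismatches are still reported, so the helper does not mask genuine differences.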