NaN comparison fails in np.testing.assert_equal #301

@justinchuby

The test utility np.testing.assert_equal treats NaN values as equal to each other. However, this does not hold for some ml_dtypes arrays, for example bfloat16:

import ml_dtypes
import numpy as np

# This will succeed 
fp32_array = np.array(np.nan, dtype=np.float32)
np.testing.assert_equal(fp32_array, fp32_array)

# This will fail
array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
np.testing.assert_equal(array, array)

The second call fails with:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "numpy/testing/_private/utils.py", line 371, in assert_equal
    return assert_array_equal(actual, desired, err_msg, verbose,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "numpy/testing/_private/utils.py", line 1051, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "numpy/testing/_private/utils.py", line 916, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: nan
Max relative difference among violations: nan
 ACTUAL: array(nan, dtype=bfloat16)
 DESIRED: array(nan, dtype=bfloat16)
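
A possible workaround until this is resolved, sketched under assumptions: since the float32 comparison above succeeds, upcasting both arrays to float32 before calling np.testing.assert_equal should let NaNs compare as equal. The helper name assert_equal_via_float32 is hypothetical and not part of NumPy or ml_dtypes, and the sketch assumes every value in the array is exactly representable in float32 (which is the case for bfloat16).

```python
import numpy as np


def assert_equal_via_float32(actual, desired):
    """Hypothetical helper: compare two arrays after upcasting to float32."""
    # Assumption: the source dtype converts to float32 without loss.
    actual32 = np.asarray(actual).astype(np.float32)
    desired32 = np.asarray(desired).astype(np.float32)
    # For native float32 arrays, assert_equal treats NaN as equal to NaN.
    np.testing.assert_equal(actual32, desired32)


try:
    import ml_dtypes  # only needed for the bfloat16 demonstration
except ImportError:
    ml_dtypes = None

if ml_dtypes is not None:
    array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
    assert_equal_via_float32(array, array)
```

This only sidesteps the failure in test code; it does not change how np.testing.assert_equal handles ml_dtypes arrays directly.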
