Skip to content

Commit 6c9775f

Browse files
author
The ml_dtypes Authors
committed
Merge pull request #172 from jakevdp:divmod
PiperOrigin-RevId: 666881270
2 parents 30f2497 + 2975e8e commit 6c9775f

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2525

2626
* Added new 8-bit float type following IEEE 754 convention:
2727
`ml_dtypes.float8_e4m3`.
28+
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.
2829

2930
## [0.4.0] - 2024-04-1
3031

ml_dtypes/_src/ufuncs.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,13 @@ struct TrueDivide {
168168
inline std::pair<float, float> divmod(float a, float b) {
169169
if (b == 0.0f) {
170170
float nan = std::numeric_limits<float>::quiet_NaN();
171-
return {nan, nan};
171+
float inf = std::numeric_limits<float>::infinity();
172+
173+
if (std::isnan(a) || (a == 0.0f)) {
174+
return {nan, nan};
175+
} else {
176+
return {std::signbit(a) == std::signbit(b) ? inf : -inf, nan};
177+
}
172178
}
173179
float mod = std::fmod(a, b);
174180
float div = (a - mod) / b;

ml_dtypes/tests/custom_float_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,47 @@ def testDivmod(self, float_type):
830830
float_type=float_type,
831831
)
832832

833+
@ignore_warning(category=RuntimeWarning, message="invalid value encountered")
834+
@ignore_warning(category=RuntimeWarning, message="divide by zero encountered")
835+
def testDivmodCornerCases(self, float_type):
836+
x = np.array(
837+
[-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan],
838+
dtype=float_type,
839+
)
840+
xf32 = x.astype("float32")
841+
out = np.divmod.outer(x, x)
842+
expected = np.divmod.outer(xf32, xf32)
843+
numpy_assert_allclose(
844+
out[0],
845+
truncate(expected[0], float_type=float_type),
846+
rtol=0.0,
847+
float_type=float_type,
848+
)
849+
numpy_assert_allclose(
850+
out[1],
851+
truncate(expected[1], float_type=float_type),
852+
rtol=0.0,
853+
float_type=float_type,
854+
)
855+
856+
@ignore_warning(category=RuntimeWarning, message="invalid value encountered")
857+
@ignore_warning(category=RuntimeWarning, message="divide by zero encountered")
858+
def testFloordivCornerCases(self, float_type):
859+
# Regression test for https://github.com/jax-ml/ml_dtypes/issues/170
860+
x = np.array(
861+
[-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan],
862+
dtype=float_type,
863+
)
864+
xf32 = x.astype("float32")
865+
out = np.floor_divide.outer(x, x)
866+
expected = np.floor_divide.outer(xf32, xf32)
867+
numpy_assert_allclose(
868+
out,
869+
truncate(expected, float_type=float_type),
870+
rtol=0.0,
871+
float_type=float_type,
872+
)
873+
833874
def testModf(self, float_type):
834875
rng = np.random.RandomState(seed=42)
835876
x = rng.randn(3, 7).astype(float_type)

0 commit comments

Comments
 (0)