@@ -830,6 +830,47 @@ def testDivmod(self, float_type):
830830 float_type = float_type ,
831831 )
832832
833+ @ignore_warning (category = RuntimeWarning , message = "invalid value encountered" )
834+ @ignore_warning (category = RuntimeWarning , message = "divide by zero encountered" )
835+ def testDivmodCornerCases (self , float_type ):
836+ x = np .array (
837+ [- np .nan , - np .inf , - 1.0 , - 0.0 , 0.0 , 1.0 , np .inf , np .nan ],
838+ dtype = float_type ,
839+ )
840+ xf32 = x .astype ("float32" )
841+ out = np .divmod .outer (x , x )
842+ expected = np .divmod .outer (xf32 , xf32 )
843+ numpy_assert_allclose (
844+ out [0 ],
845+ truncate (expected [0 ], float_type = float_type ),
846+ rtol = 0.0 ,
847+ float_type = float_type ,
848+ )
849+ numpy_assert_allclose (
850+ out [1 ],
851+ truncate (expected [1 ], float_type = float_type ),
852+ rtol = 0.0 ,
853+ float_type = float_type ,
854+ )
855+
856+ @ignore_warning (category = RuntimeWarning , message = "invalid value encountered" )
857+ @ignore_warning (category = RuntimeWarning , message = "divide by zero encountered" )
858+ def testFloordivCornerCases (self , float_type ):
859+ # Regression test for https://github.com/jax-ml/ml_dtypes/issues/170
860+ x = np .array (
861+ [- np .nan , - np .inf , - 1.0 , - 0.0 , 0.0 , 1.0 , np .inf , np .nan ],
862+ dtype = float_type ,
863+ )
864+ xf32 = x .astype ("float32" )
865+ out = np .floor_divide .outer (x , x )
866+ expected = np .floor_divide .outer (xf32 , xf32 )
867+ numpy_assert_allclose (
868+ out ,
869+ truncate (expected , float_type = float_type ),
870+ rtol = 0.0 ,
871+ float_type = float_type ,
872+ )
873+
833874 def testModf (self , float_type ):
834875 rng = np .random .RandomState (seed = 42 )
835876 x = rng .randn (3 , 7 ).astype (float_type )
0 commit comments