@@ -3571,9 +3571,37 @@ def test_split(self):
35713571 self .assertEqual (len (knp .Split (2 )(x )), 2 )
35723572
35733573 def test_sqrt (self ):
3574- x = np .array ([[1 , 4 , 9 ], [16 , 25 , 36 ]])
3575- self .assertAllClose (knp .sqrt (x ), np .sqrt (x ))
3576- self .assertAllClose (knp .Sqrt ()(x ), np .sqrt (x ))
3574+ x = np .array ([[1 , 4 , 9 ], [16 , 25 , 36 ]], dtype = "float32" )
3575+ ref_y = np .sqrt (x )
3576+ y = knp .sqrt (x )
3577+ self .assertEqual (standardize_dtype (y .dtype ), "float32" )
3578+ self .assertAllClose (y , ref_y )
3579+ y = knp .Sqrt ()(x )
3580+ self .assertEqual (standardize_dtype (y .dtype ), "float32" )
3581+ self .assertAllClose (y , ref_y )
3582+
3583+ @pytest .mark .skipif (
3584+ backend .backend () == "jax" , reason = "JAX does not support float64."
3585+ )
3586+ def test_sqrt_float64 (self ):
3587+ x = np .array ([[1 , 4 , 9 ], [16 , 25 , 36 ]], dtype = "float64" )
3588+ ref_y = np .sqrt (x )
3589+ y = knp .sqrt (x )
3590+ self .assertEqual (standardize_dtype (y .dtype ), "float64" )
3591+ self .assertAllClose (y , ref_y )
3592+ y = knp .Sqrt ()(x )
3593+ self .assertEqual (standardize_dtype (y .dtype ), "float64" )
3594+ self .assertAllClose (y , ref_y )
3595+
3596+ def test_sqrt_int32 (self ):
3597+ x = np .array ([[1 , 4 , 9 ], [16 , 25 , 36 ]], dtype = "int32" )
3598+ ref_y = np .sqrt (x )
3599+ y = knp .sqrt (x )
3600+ self .assertEqual (standardize_dtype (y .dtype ), "float32" )
3601+ self .assertAllClose (y , ref_y )
3602+ y = knp .Sqrt ()(x )
3603+ self .assertEqual (standardize_dtype (y .dtype ), "float32" )
3604+ self .assertAllClose (y , ref_y )
35773605
35783606 def test_stack (self ):
35793607 x = np .array ([[1 , 2 , 3 ], [3 , 2 , 1 ]])
@@ -3704,6 +3732,8 @@ def test_arange(self):
37043732 self .assertAllClose (knp .Arange ()(3 , 7 ), np .arange (3 , 7 ))
37053733 self .assertAllClose (knp .Arange ()(3 , 7 , 2 ), np .arange (3 , 7 , 2 ))
37063734
3735+ self .assertEqual (standardize_dtype (knp .arange (3 ).dtype ), "int32" )
3736+
37073737 def test_full (self ):
37083738 self .assertAllClose (knp .full ([2 , 3 ], 0 ), np .full ([2 , 3 ], 0 ))
37093739 self .assertAllClose (knp .full ([2 , 3 ], 0.1 ), np .full ([2 , 3 ], 0.1 ))
0 commit comments