Skip to content

Commit 3462f28

Browse files
committed
fix utils_test
1 parent a4be45d commit 3462f28

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def f(x: Any) -> jnp.ndarray:
5454
gt_fn = jax.vmap(f, in_axes=in_axes)
5555
fn = utils.batched_vmap(f, in_axes=in_axes, batch_size=2)
5656

57-
np.testing.assert_array_equal(gt_fn(x), fn(x), rtol=1e-5, atol=1e-5)
57+
np.testing.assert_array_almost_equal(gt_fn(x), fn(x), decimal=4)
5858

5959
@pytest.mark.parametrize("batch_size", [1, 7, 67, 133])
6060
@pytest.mark.parametrize("in_axes", [0, 1, -1, -2, [0, None], (0, -2)])

0 commit comments

Comments
 (0)