We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a4be45d commit 3462f28Copy full SHA for 3462f28
tests/utils_test.py
@@ -54,7 +54,7 @@ def f(x: Any) -> jnp.ndarray:
54
gt_fn = jax.vmap(f, in_axes=in_axes)
55
fn = utils.batched_vmap(f, in_axes=in_axes, batch_size=2)
56
57
- np.testing.assert_array_equal(gt_fn(x), fn(x), rtol=1e-5, atol=1e-5)
+ np.testing.assert_array_almost_equal(gt_fn(x), fn(x), decimal=4)
58
59
@pytest.mark.parametrize("batch_size", [1, 7, 67, 133])
60
@pytest.mark.parametrize("in_axes", [0, 1, -1, -2, [0, None], (0, -2)])
0 commit comments