From ae5c4659303493834723897d8349473a0ab7d0b0 Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 30 May 2023 19:31:16 +0900 Subject: [PATCH] fix comparison between signatory and signax output --- tests/test_signature.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_signature.py b/tests/test_signature.py index 4a2add0..11322a1 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -15,6 +15,7 @@ rng = default_rng() jax.config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) def test_signature_1d_path(): @@ -39,7 +40,7 @@ def test_multi_signature_combine(): jax_signatures = [jnp.array(x) for x in signatures] jax_output = multi_signature_combine(jax_signatures) - jax_sum = sum(jnp.sum(x) for x in jax_output) + jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_output]) torch_signatures = [] for i in range(batch_size): @@ -51,8 +52,8 @@ def test_multi_signature_combine(): torch_output = signatory.multi_signature_combine( torch_signatures, input_channels=dim, depth=len(signatures) ) - torch_sum = torch_output.sum().item() - assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) + torch_output = jnp.array(torch_output.numpy()) + assert jnp.allclose(jax_output, torch_output) def test_signature_batch(): @@ -65,25 +66,23 @@ def test_signature_batch(): path = rng.standard_normal((length, dim)) jax_signature = signature_batch(path, depth, n_chunks) - jax_sum = sum(jnp.sum(x) for x in jax_signature) + jax_signature = jnp.concatenate([jnp.ravel(x) for x in jax_signature]) torch_path = torch.tensor(path) torch_signature = signatory.signature(torch_path[None, ...], depth=depth) - torch_sum = torch_signature.sum().item() + torch_signature = jnp.array(torch_signature.numpy()) - # TODO: this has a low precision error - assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) + assert jnp.allclose(jax_signature, torch_signature) # has remainder case length = 1005 path = rng.standard_normal((length, dim)) jax_signature = signature_batch(path, depth, n_chunks) - jax_sum = sum(jnp.sum(x) for x in jax_signature) + jax_signature = jnp.concatenate([jnp.ravel(x) for x in jax_signature]) torch_path = torch.tensor(path) torch_signature = signatory.signature(torch_path[None, ...], depth=depth) - torch_sum = torch_signature.sum().item() + torch_signature = jnp.array(torch_signature.numpy()) - # TODO: this has a low precision error - assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) + assert jnp.allclose(jax_signature, torch_signature)