Skip to content

Commit

Permalink
Merge pull request #26 from anh-tong/fix-lower-precision-test
Browse files Browse the repository at this point in the history
Fix comparison between signatory and signax output
  • Loading branch information
anh-tong authored May 30, 2023
2 parents 9c3be9d + ae5c465 commit b192269
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
rng = default_rng()

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)


def test_signature_1d_path():
Expand All @@ -39,7 +40,7 @@ def test_multi_signature_combine():
jax_signatures = [jnp.array(x) for x in signatures]

jax_output = multi_signature_combine(jax_signatures)
jax_sum = sum(jnp.sum(x) for x in jax_output)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_output])

torch_signatures = []
for i in range(batch_size):
Expand All @@ -51,8 +52,8 @@ def test_multi_signature_combine():
torch_output = signatory.multi_signature_combine(
torch_signatures, input_channels=dim, depth=len(signatures)
)
torch_sum = torch_output.sum().item()
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)
torch_output = jnp.array(torch_output.numpy())
assert jnp.allclose(jax_output, torch_output)


def test_signature_batch():
Expand All @@ -65,25 +66,23 @@ def test_signature_batch():

path = rng.standard_normal((length, dim))
jax_signature = signature_batch(path, depth, n_chunks)
jax_sum = sum(jnp.sum(x) for x in jax_signature)
jax_signature = jnp.concatenate([jnp.ravel(x) for x in jax_signature])

torch_path = torch.tensor(path)
torch_signature = signatory.signature(torch_path[None, ...], depth=depth)
torch_sum = torch_signature.sum().item()
torch_signature = jnp.array(torch_signature.numpy())

# TODO: this has a low precision error
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)
assert jnp.allclose(jax_signature, torch_signature)

# has remainder case
length = 1005
path = rng.standard_normal((length, dim))

jax_signature = signature_batch(path, depth, n_chunks)
jax_sum = sum(jnp.sum(x) for x in jax_signature)
jax_signature = jnp.concatenate([jnp.ravel(x) for x in jax_signature])

torch_path = torch.tensor(path)
torch_signature = signatory.signature(torch_path[None, ...], depth=depth)
torch_sum = torch_signature.sum().item()
torch_signature = jnp.array(torch_signature.numpy())

# TODO: this has a low precision error
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)
assert jnp.allclose(jax_signature, torch_signature)

0 comments on commit b192269

Please sign in to comment.