Skip to content

Commit

Permalink
Extend pytree test
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 14, 2024
1 parent ec67648 commit b7ff488
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import redirect_stdout

import jax
import jax.numpy as jnp

import jaxsim.api as js

Expand Down Expand Up @@ -45,3 +46,21 @@ def test_call_jit_compiled_function_passing_different_objects(
f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}"
not in stdout
)

# Define a new JIT-compiled function and check that is not recompiled for
# different model objects having the same pytree structure.
@jax.jit
def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData):
# Return random elements from model and data, just to have something returned.
return (
jnp.sum(model.kin_dyn_parameters.link_parameters.mass),
data.base_position(),
)

data1 = js.data.JaxSimModelData.build(model=model1)

_ = my_jit_function(model=model1, data=data1)
assert my_jit_function._cache_size() == 1

_ = my_jit_function(model=model2, data=data1)
assert my_jit_function._cache_size() == 1

0 comments on commit b7ff488

Please sign in to comment.