diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 712b61441..8d063dc87 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -3,6 +3,7 @@ from contextlib import redirect_stdout import jax +import jax.numpy as jnp import jaxsim.api as js @@ -45,3 +46,21 @@ def test_call_jit_compiled_function_passing_different_objects( f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}" not in stdout ) + + # Define a new JIT-compiled function and check that is not recompiled for + # different model objects having the same pytree structure. + @jax.jit + def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): + # Return random elements from model and data, just to have something returned. + return ( + jnp.sum(model.kin_dyn_parameters.link_parameters.mass), + data.base_position(), + ) + + data1 = js.data.JaxSimModelData.build(model=model1) + + _ = my_jit_function(model=model1, data=data1) + assert my_jit_function._cache_size() == 1 + + _ = my_jit_function(model=model2, data=data1) + assert my_jit_function._cache_size() == 1