From 2dc884f395db29bd628af79204201a3f96d95a88 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 19 Jun 2024 11:40:30 +0200 Subject: [PATCH 1/2] Always forward to the integrator the model and data of jaxsim.model.step --- src/jaxsim/api/model.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 10e420129..5bb45fa9d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1841,7 +1841,13 @@ def step( and the new state of the integrator. """ - integrator_kwargs = kwargs if kwargs is not None else dict() + # Extract the integrator kwargs. + # The following logic allows using integrators having kwargs colliding with the + # kwargs of this step function. + kwargs = kwargs if kwargs is not None else {} + integrator_kwargs = kwargs.pop("integrator_kwargs", {}) + integrator_kwargs = kwargs | integrator_kwargs + integrator_state = integrator_state if integrator_state is not None else dict() # Extract the initial resources. @@ -1855,8 +1861,21 @@ def step( t0=jnp.array(t0_ns / 1e9).astype(float), dt=dt, params=integrator_state_x0, + # Always inject the current (model, data) pair into the system dynamics + # considered by the integrator, and include the input variables represented + # by the pair (joint_forces, link_forces). + # Note that the wrapper of the system dynamics will override (state_x0, t0) + # inside the passed data even if it is not strictly needed. This logic is + # necessary to re-use the jit-compiled step function of compatible pytrees + # of model and data produced e.g. by parameterized applications. **( - dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs + dict( + model=model, + data=data, + joint_forces=joint_forces, + link_forces=link_forces, + ) + | integrator_kwargs ), ) From 72506664805a2e62c98445e4f3a8bde6dbba84ca Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 19 Jun 2024 12:00:28 +0200 Subject: [PATCH 2/2] Address deprecation warning --- src/jaxsim/integrators/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 26e9cb47a..93b6d4f17 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -422,7 +422,9 @@ def compute_ki() -> jax.Array: # Update the FSAL property for the next iteration. if self.has_fsal: - self.params["dxdt0"] = jax.tree_map(lambda l: l[self.index_of_fsal], K) + self.params["dxdt0"] = jax.tree_util.tree_map( + lambda l: l[self.index_of_fsal], K + ) # Compute the output state. # Note that z contains as many new states as the rows of `b.T`.