Skip to content

Commit

Permalink
Merge pull request #183 from ami-iit/allow_overriding_model_and_data_…
Browse files Browse the repository at this point in the history
…in_integrator_kwargs

Always forward to the integrator the `model` and `data` passed to `jaxsim.api.model.step`
  • Loading branch information
diegoferigo committed Jun 20, 2024
2 parents 3359e9a + 7250666 commit 0cfd5f7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
23 changes: 21 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,13 @@ def step(
and the new state of the integrator.
"""

integrator_kwargs = kwargs if kwargs is not None else dict()
# Extract the integrator kwargs.
# The following logic allows using integrators having kwargs colliding with the
# kwargs of this step function.
kwargs = kwargs if kwargs is not None else {}
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
integrator_kwargs = kwargs | integrator_kwargs

integrator_state = integrator_state if integrator_state is not None else dict()

# Extract the initial resources.
Expand All @@ -1855,8 +1861,21 @@ def step(
t0=jnp.array(t0_ns / 1e9).astype(float),
dt=dt,
params=integrator_state_x0,
# Always inject the current (model, data) pair into the system dynamics
# considered by the integrator, and include the input variables represented
# by the pair (joint_forces, link_forces).
# Note that the wrapper of the system dynamics will override (state_x0, t0)
# inside the passed data even if it is not strictly needed. This logic is
# necessary to re-use the jit-compiled step function of compatible pytrees
# of model and data produced e.g. by parameterized applications.
**(
dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
dict(
model=model,
data=data,
joint_forces=joint_forces,
link_forces=link_forces,
)
| integrator_kwargs
),
)

Expand Down
4 changes: 3 additions & 1 deletion src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ def compute_ki() -> jax.Array:

# Update the FSAL property for the next iteration.
if self.has_fsal:
self.params["dxdt0"] = jax.tree_map(lambda l: l[self.index_of_fsal], K)
self.params["dxdt0"] = jax.tree_util.tree_map(
lambda l: l[self.index_of_fsal], K
)

# Compute the output state.
# Note that z contains as many new states as the rows of `b.T`.
Expand Down

0 comments on commit 0cfd5f7

Please sign in to comment.