Skip to content

Commit

Permalink
Always forward to the integrator the model and data of jaxsim.model.step
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 19, 2024
1 parent 5e4adaa commit 2dc884f
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,13 @@ def step(
and the new state of the integrator.
"""

integrator_kwargs = kwargs if kwargs is not None else dict()
# Extract the integrator kwargs.
# The following logic allows using integrators having kwargs colliding with the
# kwargs of this step function.
kwargs = kwargs if kwargs is not None else {}
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
integrator_kwargs = kwargs | integrator_kwargs

integrator_state = integrator_state if integrator_state is not None else dict()

# Extract the initial resources.
Expand All @@ -1855,8 +1861,21 @@ def step(
t0=jnp.array(t0_ns / 1e9).astype(float),
dt=dt,
params=integrator_state_x0,
# Always inject the current (model, data) pair into the system dynamics
# considered by the integrator, and include the input variables represented
# by the pair (joint_forces, link_forces).
# Note that the wrapper of the system dynamics will override (state_x0, t0)
# inside the passed data even if it is not strictly needed. This logic is
# necessary to re-use the jit-compiled step function of compatible pytrees
# of model and data produced e.g. by parameterized applications.
**(
dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
dict(
model=model,
data=data,
joint_forces=joint_forces,
link_forces=link_forces,
)
| integrator_kwargs
),
)

Expand Down

0 comments on commit 2dc884f

Please sign in to comment.