diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 93b6d4f17..5af59225f 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -10,7 +10,6 @@ import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.math import Quaternion from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability try: @@ -548,17 +547,11 @@ def integrate_rk_stage( op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k) - W_Q_B_t0 = x0.physics_model.base_quaternion - W_ω_WB_t0 = x0.physics_model.base_angular_velocity + W_Q_B_tf = xf.physics_model.base_quaternion return xf.replace( physics_model=xf.physics_model.replace( - base_quaternion=Quaternion.integration( - quaternion=W_Q_B_t0, - dt=dt, - omega=W_ω_WB_t0, - omega_in_body_fixed=False, - ), + base_quaternion=W_Q_B_tf / jnp.linalg.norm(W_Q_B_tf) ) )