diff --git a/src/jaxsim/simulation/integrators.py b/src/jaxsim/simulation/integrators.py index 2ce1c6e53..ca0e00930 100644 --- a/src/jaxsim/simulation/integrators.py +++ b/src/jaxsim/simulation/integrators.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Tuple, Union +import enum +from typing import Any, Callable import jax import jax.numpy as jnp @@ -19,30 +20,39 @@ StateDerivative = jtp.PyTree StateDerivativeCallable = Callable[ - [State, Time], Tuple[StateDerivative, Dict[str, Any]] + [State, Time], tuple[StateDerivative, dict[str, Any]] ] +class IntegratorType(enum.IntEnum): + RungeKutta4 = enum.auto() + EulerForward = enum.auto() + EulerSemiImplicit = enum.auto() + EulerSemiImplicitManifold = enum.auto() + + # ======================= # Single-step integration # ======================= -def odeint_euler_one_step( +def integrator_fixed_single_step( dx_dt: StateDerivativeCallable, - x0: State, + x0: State | ODEState, t0: Time, tf: Time, + integrator_type: IntegratorType, num_sub_steps: int = 1, -) -> Tuple[State, Dict[str, Any]]: +) -> tuple[State | ODEState, dict[str, Any]]: """ - Forward Euler integrator. + Advance a state vector by integrating a sytem dynamics with a fixed-step integrator. Args: dx_dt: Callable that computes the state derivative. x0: Initial state. t0: Initial time. tf: Final time. + integrator_type: Integrator type. num_sub_steps: Number of sub-steps to break the integration into. Returns: @@ -55,10 +65,14 @@ def odeint_euler_one_step( sub_step_dt = dt / num_sub_steps # Initialize the carry - Carry = Tuple[State, Time] + Carry = tuple[State | ODEState, Time] carry_init: Carry = (x0, t0) - def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: + def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: + """ + Forward Euler integrator. + """ + # Unpack the carry x_t0, t0 = carry @@ -78,48 +92,11 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: return carry, None - # Integrate over the given horizon - (x_tf, _), _ = jax.lax.scan( - f=body_fun, init=carry_init, xs=None, length=num_sub_steps - ) + def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: + """ + Runge-Kutta 4 integrator. + """ - # Compute the aux dictionary at t0 - _, aux_t0 = dx_dt(x0, t0) - - return x_tf, aux_t0 - - -def odeint_rk4_one_step( - dx_dt: StateDerivativeCallable, - x0: State, - t0: Time, - tf: Time, - num_sub_steps: int = 1, -) -> Tuple[State, Dict[str, Any]]: - """ - Runge-Kutta 4 integrator. - - Args: - dx_dt: Callable that computes the state derivative. - x0: Initial state. - t0: Initial time. - tf: Final time. - num_sub_steps: Number of sub-steps to break the integration into. - - Returns: - The final state and a dictionary including auxiliary data at t0. - """ - - # Compute the sub-step size. - # We break dt in configurable sub-steps. - dt = tf - t0 - sub_step_dt = dt / num_sub_steps - - # Initialize the carry - Carry = Tuple[State, Time] - carry_init: Carry = (x0, t0) - - def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: # Unpack the carry x_t0, t0 = carry @@ -148,49 +125,11 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: return carry, None - # Integrate over the given horizon - (x_tf, _), _ = jax.lax.scan( - f=body_fun, init=carry_init, xs=None, length=num_sub_steps - ) - - # Compute the aux dictionary at t0 - _, aux_t0 = dx_dt(x0, t0) - - return x_tf, aux_t0 - - -def odeint_euler_semi_implicit_one_step( - dx_dt: StateDerivativeCallable, - x0: ODEState, - t0: Time, - tf: Time, - num_sub_steps: int = 1, -) -> Tuple[ODEState, Dict[str, Any]]: - """ - Semi-implicit Euler integrator. - - Args: - dx_dt: Callable that computes the state derivative. - x0: Initial state as ODEState object. - t0: Initial time. - tf: Final time. - num_sub_steps: Number of sub-steps to break the integration into. - - Returns: - A tuple having as first element the final state as ODEState object, - and as second element a dictionary including auxiliary data at t0. - """ + def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: + """ + Semi-implicit Euler integrator. + """ - # Compute the sub-step size. - # We break dt in configurable sub-steps. - dt = tf - t0 - sub_step_dt = dt / num_sub_steps - - # Initialize the carry - Carry = Tuple[ODEState, Time] - carry_init: Carry = (x0, t0) - - def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: # Unpack the carry x_t0, t0 = carry @@ -218,6 +157,7 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: # 2. Compute the derivative of the generalized position # 3. Integrate the implicit velocities # 4. Integrate the remaining state + # 5. Outside the loop: integrate the quaternion on SO(3) manifold # ---------------------------------------------------------------- # 1. Integrate the accelerations obtaining the implicit velocities @@ -254,13 +194,27 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3] # Compute the derivative of the generalized position - d_pos_tf = jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]]) + d_pos_tf = ( + jnp.hstack([BW_vl_WB, vel_tf[6:]]) + if integrator_type is IntegratorType.EulerSemiImplicitManifold + else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]]) + ) # ------------------------------------ # 3. Integrate the implicit velocities # ------------------------------------ pos_tf = pos_t0 + sub_step_dt * d_pos_tf + joint_positions = ( + pos_tf[3:] + if integrator_type is IntegratorType.EulerSemiImplicitManifold + else pos_tf[7:] + ) + base_quaternion = ( + jnp.zeros_like(x_t0.base_quaternion) + if integrator_type is IntegratorType.EulerSemiImplicitManifold + else pos_tf[3:7] + ) # --------------------------------- # 4. Integrate the remaining state @@ -275,8 +229,8 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: x_tf = ODEState( physics_model=PhysicsModelState( base_position=pos_tf[0:3], - base_quaternion=pos_tf[3:7], - joint_positions=pos_tf[7:], + base_quaternion=base_quaternion, + joint_positions=joint_positions, base_linear_velocity=vel_tf[0:3], base_angular_velocity=vel_tf[3:6], joint_velocities=vel_tf[6:], @@ -294,176 +248,43 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: return carry, None - # Integrate over the given horizon - (x_tf, _), _ = jax.lax.scan( - f=body_fun, init=carry_init, xs=None, length=num_sub_steps - ) - - # Compute the aux dictionary at t0 - _, aux_t0 = dx_dt(x0, t0) - - return x_tf, aux_t0 - - -def odeint_euler_semi_implicit_manifold_one_step( - dx_dt: StateDerivativeCallable, - x0: ODEState, - t0: Time, - tf: Time, - num_sub_steps: int = 1, -) -> Tuple[ODEState, Dict[str, Any]]: - """ - Semi-implicit Euler integrator with quaternion integration on SO(3). - - Args: - dx_dt: Callable that computes the state derivative. - x0: Initial state as ODEState object. - t0: Initial time. - tf: Final time. - num_sub_steps: Number of sub-steps to break the integration into. - - Returns: - A tuple having as first element the final state as ODEState object, - and as second element a dictionary including auxiliary data at t0. - """ - - # Compute the sub-step size. - # We break dt in configurable sub-steps. - dt = tf - t0 - sub_step_dt = dt / num_sub_steps - - # Integrate the quaternion on its manifold using the new angular velocity - - # Initialize the carry - Carry = Tuple[ODEState, Time] - carry_init: Carry = (x0, t0) - - def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: - # Unpack the carry - x_t0, t0 = carry - - # Compute the state derivative. - # We only keep the quantities related to the acceleration and discard those - # related to the velocity since we are going to use those implicitly integrated - # from the accelerations. - StateDerivative = ODEState - dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0] - - # Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ. - # This integrator, contrarily to most of the other ones, is not generic. - # It expects to operate on an x object of class ODEState. - pos_t0 = x_t0.physics_model.position() - vel_t0 = x_t0.physics_model.velocity() - - # Extract the velocity derivative - d_vel_dt = dxdt_t0.physics_model.velocity() - - # ============================================= - # Perform semi-implicit Euler integration [1-4] - # ============================================= + _integrator_registry = { + IntegratorType.RungeKutta4: rk4_body_fun, + IntegratorType.EulerForward: forward_euler_body_fun, + IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun, + IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun, + } - # 1. Integrate the accelerations obtaining the implicit velocities - # 2. Compute the derivative of the generalized position (w/o quaternion) - # 3. Integrate the implicit velocities (w/o quaternion) - # 4. Integrate the remaining state - # 5. Outside the loop: integrate the quaternion on SO(3) manifold - - # ---------------------------------------------------------------- - # 1. Integrate the accelerations obtaining the implicit velocities - # ---------------------------------------------------------------- - - vel_tf = vel_t0 + sub_step_dt * d_vel_dt - - # ---------------------------------------------------------------------- - # 2. Compute the derivative of the generalized position (w/o quaternion) - # ---------------------------------------------------------------------- - - # Compute the transform of the mixed base frame at t0 - W_H_BW = jnp.vstack( - [ - jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]), - jnp.array([0, 0, 0, 1]), - ] - ) - - # The derivative W_ṗ_B of the base position is the linear component of the - # mixed velocity B[W]_v_WB. We need to compute it from the velocity in - # inertial-fixed representation W_vl_WB. - W_v_WB = vel_tf[0:6] - BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3] - - # Compute the derivative of the generalized position excluding the quaternion - pos_no_quat_t0 = jnp.hstack([pos_t0[0:3], pos_t0[7:]]) - d_pos_no_quat_tf = jnp.hstack([BW_vl_WB, vel_tf[6:]]) - - # ----------------------------------------------------- - # 3. Integrate the implicit velocities (w/o quaternion) - # ----------------------------------------------------- - - pos_no_quat_tf = pos_no_quat_t0 + sub_step_dt * d_pos_no_quat_tf - - # --------------------------------- - # 4. Integrate the remaining state - # --------------------------------- - - # Integrate the derivative of the tangential material deformation - m = x_t0.soft_contacts.tangential_deformation - ṁ = dxdt_t0.soft_contacts.tangential_deformation - tangential_deformation_tf = m + sub_step_dt * ṁ - - # Pack the new state into an ODEState object. - # We store a zero quaternion as placeholder, it will be replaced later. - x_tf = ODEState( - physics_model=PhysicsModelState( - base_position=pos_no_quat_tf[0:3], - base_quaternion=jnp.zeros_like(x_t0.physics_model.base_quaternion), - joint_positions=pos_no_quat_tf[3:], - base_linear_velocity=vel_tf[0:3], - base_angular_velocity=vel_tf[3:6], - joint_velocities=vel_tf[6:], - ), - soft_contacts=SoftContactsState( - tangential_deformation=tangential_deformation_tf - ), - ) - - # Update the time - tf = t0 + sub_step_dt - - # Pack the carry - carry = (x_tf, tf) - - return carry, None + # Get the body function for the selected integrator + body_fun = _integrator_registry[integrator_type] # Integrate over the given horizon - (x_no_quat_tf, _), _ = jax.lax.scan( + (x_tf, _), _ = jax.lax.scan( f=body_fun, init=carry_init, xs=None, length=num_sub_steps ) - # --------------------------------------------- - # 5. Integrate the quaternion on SO(3) manifold - # --------------------------------------------- - - # Indices to convert quaternions between serializations - to_xyzw = jnp.array([1, 2, 3, 0]) - to_wxyz = jnp.array([3, 0, 1, 2]) + if integrator_type is IntegratorType.EulerSemiImplicitManifold: + # Indices to convert quaternions between serializations + to_xyzw = jnp.array([1, 2, 3, 0]) + to_wxyz = jnp.array([3, 0, 1, 2]) - # Get the initial quaternion and the implicitly integrated angular velocity - W_ω_WB_tf = x_no_quat_tf.physics_model.base_angular_velocity - W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(x0.physics_model.base_quaternion[to_xyzw]) + # Get the initial quaternion and the implicitly integrated angular velocity + W_ω_WB_tf = x_tf.physics_model.base_angular_velocity + W_Q_B_t0 = so3.SO3.from_quaternion_xyzw( + x0.physics_model.base_quaternion[to_xyzw] + ) - # Integrate the quaternion on its manifold using the implicit angular velocity, - # transformed in body-fixed representation since jaxlie uses this convention - B_R_W = W_Q_B_t0.inverse().as_matrix() - W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf) + # Integrate the quaternion on its manifold using the implicit angular velocity, + # transformed in body-fixed representation since jaxlie uses this convention + B_R_W = W_Q_B_t0.inverse().as_matrix() + W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf) - # Store the quaternion in the final state - x_tf = x_no_quat_tf.replace( - physics_model=x_no_quat_tf.physics_model.replace( - base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz] + # Store the quaternion in the final state + x_tf = x_tf.replace( + physics_model=x_tf.physics_model.replace( + base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz] + ) ) - ) # Compute the aux dictionary at t0 _, aux_t0 = dx_dt(x0, t0) @@ -477,10 +298,10 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]: def integrate_single_step_over_horizon( - integrator_single_step: Callable[[Time, Time, State], Tuple[State, Dict[str, Any]]], + integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]], t: TimeHorizon, x0: State, -) -> Tuple[State, Dict[str, Any]]: +) -> tuple[State, dict[str, Any]]: """ Integrate a single-step integrator over a given horizon. @@ -496,7 +317,7 @@ def integrate_single_step_over_horizon( # Initialize the carry carry_init = (x0, t) - def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]: + def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]: # Unpack the carry x_t0, horizon = carry @@ -526,96 +347,17 @@ def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]: # =================================================================== -def odeint_euler( - func, - y0: State, - t: TimeHorizon, - *args, - num_sub_steps: int = 1, - return_aux: bool = False -) -> Union[State, Tuple[State, Dict[str, Any]]]: - """ - Integrate a system of ODEs using the Euler method. - - Args: - func: A function that computes the time-derivative of the state. - y0: The initial state. - t: The vector of time instants of the integration horizon. - *args: Additional arguments to be passed to the function func. - num_sub_steps: The number of sub-steps to be performed within each integration step. - return_aux: Whether to return the auxiliary data produced by the integrator. - - Returns: - The state of the system at the end of the integration horizon, and optionally - the auxiliary data produced by the integrator. - """ - - # Close func over additional inputs and parameters - dx_dt_closure_aux = lambda x, ts: func(x, ts, *args) - - # Close one-step integration over its arguments - integrator_single_step = lambda t0, tf, x0: odeint_euler_one_step( - dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps - ) - - # Integrate the state and compute optional auxiliary data over the horizon - out, aux = integrate_single_step_over_horizon( - integrator_single_step=integrator_single_step, t=t, x0=y0 - ) - - return (out, aux) if return_aux else out - - -def odeint_euler_semi_implicit( - func, - y0: ODEState, - t: TimeHorizon, - *args, - num_sub_steps: int = 1, - return_aux: bool = False -) -> Union[ODEState, Tuple[ODEState, Dict[str, Any]]]: - """ - Integrate a system of ODEs using the Semi-Implicit Euler method. - - Args: - func: A function that computes the time-derivative of the state. - y0: The initial state as ODEState object. - t: The vector of time instants of the integration horizon. - *args: Additional arguments to be passed to the function func. - num_sub_steps: The number of sub-steps to be performed within each integration step. - return_aux: Whether to return the auxiliary data produced by the integrator. - - Returns: - The state of the system at the end of the integration horizon as ODEState object, - and optionally the auxiliary data produced by the integrator. - """ - - # Close func over additional inputs and parameters - dx_dt_closure_aux = lambda x, ts: func(x, ts, *args) - - # Close one-step integration over its arguments - integrator_single_step = lambda t0, tf, x0: odeint_euler_semi_implicit_one_step( - dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps - ) - - # Integrate the state and compute optional auxiliary data over the horizon - out, aux = integrate_single_step_over_horizon( - integrator_single_step=integrator_single_step, t=t, x0=y0 - ) - - return (out, aux) if return_aux else out - - -def odeint_rk4( +def odeint( func, y0: State, t: TimeHorizon, *args, num_sub_steps: int = 1, - return_aux: bool = False -) -> Union[State, Tuple[State, Dict[str, Any]]]: + return_aux: bool = False, + integrator_type: IntegratorType = None, +): """ - Integrate a system of ODEs using the Runge-Kutta 4 method. + Integrate a system of ODEs with a fixed-step integrator. Args: func: A function that computes the time-derivative of the state. @@ -634,8 +376,13 @@ def odeint_rk4( dx_dt_closure = lambda x, ts: func(x, ts, *args) # Close one-step integration over its arguments - integrator_single_step = lambda t0, tf, x0: odeint_rk4_one_step( - dx_dt=dx_dt_closure, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps + integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step( + dx_dt=dx_dt_closure, + x0=x0, + t0=t0, + tf=tf, + num_sub_steps=num_sub_steps, + integrator_type=integrator_type, ) # Integrate the state and compute optional auxiliary data over the horizon diff --git a/src/jaxsim/simulation/ode_integration.py b/src/jaxsim/simulation/ode_integration.py index dd0b9d33d..90ad227ac 100644 --- a/src/jaxsim/simulation/ode_integration.py +++ b/src/jaxsim/simulation/ode_integration.py @@ -10,21 +10,7 @@ from jaxsim.physics.algos.terrain import FlatTerrain, Terrain from jaxsim.physics.model.physics_model import PhysicsModel from jaxsim.simulation import integrators, ode - - -class IntegratorType(enum.IntEnum): - RungeKutta4 = enum.auto() - EulerForward = enum.auto() - EulerSemiImplicit = enum.auto() - EulerSemiImplicitManifold = enum.auto() - - -_integrator_registry = { - IntegratorType.RungeKutta4: integrators.odeint_rk4, - IntegratorType.EulerForward: integrators.odeint_euler, - IntegratorType.EulerSemiImplicit: integrators.odeint_euler_semi_implicit, - IntegratorType.EulerSemiImplicitManifold: integrators.odeint_euler_semi_implicit_manifold_one_step, -} +from jaxsim.simulation.integrators import IntegratorType @jax.jit @@ -62,12 +48,13 @@ def ode_integration_fixed_step( ) # Integrate over the horizon - out = _integrator_registry[integrator_type]( + out = integrators.odeint( func=dx_dt_closure, y0=x0, t=t, num_sub_steps=num_sub_steps, return_aux=return_aux, + integrator_type=integrator_type, ) # Return output pytree and, optionally, the aux dict