Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add semi-implicit Euler scheme with quaternion integrated on manifold #73

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 170 additions & 2 deletions src/jaxsim/simulation/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from jaxsim.physics.algos.soft_contacts import SoftContactsState
from jaxsim.physics.model.physics_model_state import PhysicsModelState
from jaxsim.simulation.ode_data import ODEState
from jaxsim.sixd import se3, so3

Time = float
TimeHorizon = jtp.Vector
Time = jtp.FloatLike
TimeStep = jtp.FloatLike
TimeHorizon = jtp.VectorLike

State = jtp.PyTree
StateDerivative = jtp.PyTree
Expand Down Expand Up @@ -303,6 +305,172 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
return x_tf, aux_t0


def odeint_euler_semi_implicit_manifold_one_step(
dx_dt: StateDerivativeCallable,
x0: ODEState,
t0: Time,
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
"""
Semi-implicit Euler integrator with quaternion integration on SO(3).

Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state as ODEState object.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.

Returns:
A tuple having as first element the final state as ODEState object,
and as second element a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
sub_step_dt = dt / num_sub_steps

# Integrate the quaternion on its manifold using the new angular velocity

# Initialize the carry
Carry = Tuple[ODEState, Time]
carry_init: Carry = (x0, t0)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
# Unpack the carry
x_t0, t0 = carry

# Compute the state derivative.
# We only keep the quantities related to the acceleration and discard those
# related to the velocity since we are going to use those implicitly integrated
# from the accelerations.
StateDerivative = ODEState
dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0]

# Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ.
# This integrator, contrarily to most of the other ones, is not generic.
# It expects to operate on an x object of class ODEState.
pos_t0 = x_t0.physics_model.position()
vel_t0 = x_t0.physics_model.velocity()

# Extract the velocity derivative
d_vel_dt = dxdt_t0.physics_model.velocity()

# =============================================
# Perform semi-implicit Euler integration [1-4]
# =============================================

# 1. Integrate the accelerations obtaining the implicit velocities
# 2. Compute the derivative of the generalized position (w/o quaternion)
# 3. Integrate the implicit velocities (w/o quaternion)
# 4. Integrate the remaining state
# 5. Outside the loop: integrate the quaternion on SO(3) manifold

# ----------------------------------------------------------------
# 1. Integrate the accelerations obtaining the implicit velocities
# ----------------------------------------------------------------

vel_tf = vel_t0 + sub_step_dt * d_vel_dt

# ----------------------------------------------------------------------
# 2. Compute the derivative of the generalized position (w/o quaternion)
# ----------------------------------------------------------------------

# Compute the transform of the mixed base frame at t0
W_H_BW = jnp.vstack(
[
jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]),
flferretti marked this conversation as resolved.
Show resolved Hide resolved
jnp.array([0, 0, 0, 1]),
]
)

# The derivative W_ṗ_B of the base position is the linear component of the
# mixed velocity B[W]_v_WB. We need to compute it from the velocity in
# inertial-fixed representation W_vl_WB.
W_v_WB = vel_tf[0:6]
BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]

# Compute the derivative of the generalized position excluding the quaternion
pos_no_quat_t0 = jnp.hstack([pos_t0[0:3], pos_t0[7:]])
d_pos_no_quat_tf = jnp.hstack([BW_vl_WB, vel_tf[6:]])

# -----------------------------------------------------
# 3. Integrate the implicit velocities (w/o quaternion)
# -----------------------------------------------------

pos_no_quat_tf = pos_no_quat_t0 + sub_step_dt * d_pos_no_quat_tf

# ---------------------------------
# 4. Integrate the remaining state
# ---------------------------------

# Integrate the derivative of the tangential material deformation
m = x_t0.soft_contacts.tangential_deformation
ṁ = dxdt_t0.soft_contacts.tangential_deformation
tangential_deformation_tf = m + sub_step_dt * ṁ

# Pack the new state into an ODEState object.
# We store a zero quaternion as placeholder, it will be replaced later.
x_tf = ODEState(
physics_model=PhysicsModelState(
base_position=pos_no_quat_tf[0:3],
base_quaternion=jnp.zeros_like(x_t0.physics_model.base_quaternion),
joint_positions=pos_no_quat_tf[3:],
base_linear_velocity=vel_tf[0:3],
base_angular_velocity=vel_tf[3:6],
joint_velocities=vel_tf[6:],
),
soft_contacts=SoftContactsState(
tangential_deformation=tangential_deformation_tf
),
)

# Update the time
tf = t0 + sub_step_dt

# Pack the carry
carry = (x_tf, tf)

return carry, None

# Integrate over the given horizon
(x_no_quat_tf, _), _ = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
)

# ---------------------------------------------
# 5. Integrate the quaternion on SO(3) manifold
# ---------------------------------------------

# Indices to convert quaternions between serializations
to_xyzw = jnp.array([1, 2, 3, 0])
to_wxyz = jnp.array([3, 0, 1, 2])

# Get the initial quaternion and the implicitly integrated angular velocity
W_ω_WB_tf = x_no_quat_tf.physics_model.base_angular_velocity
W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(x0.physics_model.base_quaternion[to_xyzw])

# Integrate the quaternion on its manifold using the implicit angular velocity,
# transformed in body-fixed representation since jaxlie uses this convention
B_R_W = W_Q_B_t0.inverse().as_matrix()
W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)

# Store the quaternion in the final state
x_tf = x_no_quat_tf.replace(
physics_model=x_no_quat_tf.physics_model.replace(
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
)
)

# Compute the aux dictionary at t0
_, aux_t0 = dx_dt(x0, t0)

return x_tf, aux_t0


# ===============================
# Adapter: single step -> horizon
# ===============================
Expand Down
32 changes: 32 additions & 0 deletions src/jaxsim/simulation/ode_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class IntegratorType(enum.IntEnum):
RungeKutta4 = enum.auto()
EulerForward = enum.auto()
EulerSemiImplicit = enum.auto()
EulerSemiImplicitManifold = enum.auto()


@jax.jit
Expand Down Expand Up @@ -94,6 +95,37 @@ def ode_integration_euler_semi_implicit(
return (state, out[1]) if return_aux else state


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler_semi_implicit_manifold(
x0: ode.ode_data.ODEState,
t: integrators.TimeHorizon,
physics_model: PhysicsModel,
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
ode_input: ode.ode_data.ODEInput = None,
*args,
num_sub_steps: int = 1,
return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: ode.dx_dt(
x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
)

# Integrate over the horizon
out = integrators.odeint_euler_semi_implicit_manifold(
func=dx_dt_closure,
y0=x0,
t=t,
num_sub_steps=num_sub_steps,
return_aux=return_aux,
)

# Return output pytree and, optionally, the aux dict
state = out if not return_aux else out[0]
return (state, out[1]) if return_aux else state


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_rk4(
x0: ode.ode_data.ODEState,
Expand Down
Loading