Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Jacobian with jax.lax.scan #16

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions src/jaxsim/physics/algos/jacobian.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -7,11 +10,7 @@
from . import utils


def jacobian(
model: PhysicsModel,
body_index: int,
q: jtp.Vector,
) -> jtp.Matrix:
def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:

_, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q)

Expand All @@ -23,27 +22,78 @@ def jacobian(
i_X_0 = jnp.zeros_like(i_X_pre)
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

for i in np.arange(start=1, stop=model.NB):
# Parent array mapping: i -> λ(i).
# Exception: λ(0) must not be used, it's initialized to -1.
λ = model.parent

# ====================
# Propagate kinematics
# ====================

PropagateKinematicsCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
propagate_kinematics_carry = (i_X_λi, i_X_0)

def propagate_kinematics(
carry: PropagateKinematicsCarry, i: jtp.Int
) -> Tuple[PropagateKinematicsCarry, None]:

i_X_λi, i_X_0 = carry

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
i_X_0 = i_X_0.at[i].set(i_X_0_i)

return (i_X_λi, i_X_0), None

(i_X_λi, i_X_0), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.NB),
)

# ============================
# Compute doubly-left Jacobian
# ============================

J = jnp.zeros(shape=(6, 6 + model.dofs()))

Jb = i_X_0[body_index]
J = J.at[0:6, 0:6].set(Jb)

for i in reversed(model.support_body_array(body_index=body_index)):
ComputeJacobianCarry = jtp.MatrixJax
compute_jacobian_carry = J

def compute_jacobian(
carry: ComputeJacobianCarry, i: jtp.Int
) -> Tuple[ComputeJacobianCarry, None]:
def update_jacobian(
carry: Tuple[ComputeJacobianCarry, jtp.Int]
) -> ComputeJacobianCarry:

J, i = carry

ii = i - 1

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())

return J

ii = i - 1
carry = jax.lax.cond(
pred=(jnp.any(i == model.support_body_array(body_index=body_index))),
true_fun=update_jacobian,
false_fun=lambda carry_i: carry_i[0],
operand=(carry, i),
)

if i == 0:
break
return carry, None

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())
J, _ = jax.lax.scan(
f=compute_jacobian,
init=compute_jacobian_carry,
xs=np.arange(start=1, stop=model.NB),
)

return J