Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute the pose and the jacobian of the implicit frames associated to collidable points #163

Merged
merged 5 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 128 additions & 2 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def collidable_point_kinematics(

from jaxsim.rbda import collidable_points

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(
model=model,
Expand Down Expand Up @@ -61,7 +62,9 @@ def collidable_point_positions(
The position of the collidable points in the world frame.
"""

return collidable_point_kinematics(model=model, data=data)[0]
W_p_Ci, _ = collidable_point_kinematics(model=model, data=data)

return W_p_Ci


@jax.jit
Expand All @@ -79,7 +82,9 @@ def collidable_point_velocities(
The 3D velocity of the collidable points.
"""

return collidable_point_kinematics(model=model, data=data)[1]
_, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data)

return W_ṗ_Ci


@jax.jit
Expand Down Expand Up @@ -269,3 +274,124 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
)

return sc_parameters


@jax.jit
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
r"""
Return the pose of the collidable points.
flferretti marked this conversation as resolved.
Show resolved Hide resolved

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The stacked SE(3) matrices of all collidable points.

Note:
Each collidable point is implicitly associated with a frame
:math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
collidable point and :math:`[L]` is the orientation frame of the link it is
rigidly attached to.
"""

# Get the transforms of the parent link of all collidable points.
W_H_L = jax.vmap(
lambda parent_link_idx: js.link.transform(
model=model, data=data, link_index=parent_link_idx
)
)(jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int))

# Build the link-to-point transform from the displacement between the link frame L
# and the implicit contact frame C.
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(
model.kin_dyn_parameters.contact_parameters.point
)

# Compose the work-to-link and link-to-point transforms.
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)


@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
output_vel_repr: VelRepr | None = None,
) -> jtp.Array:
r"""
Return the free-floating Jacobian of the collidable points.

Args:
model: The model to consider.
data: The data of the considered model.
output_vel_repr:
The output velocity representation of the free-floating jacobian.

Returns:
The stacked 6×(6+n) free-floating jacobians of the frames associated to the
collidable points.

Note:
Each collidable point is implicitly associated with a frame
:math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
collidable point and :math:`[L]` is the orientation frame of the link it is
rigidly attached to.
"""

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# For each collidable point, get the Jacobians of their parent link.
# In inertial-fixed output representation, the Jacobian of the parent link is also
# the Jacobian of the frame C implicitly associated with the collidable point.
W_J_WC = W_J_WL = jax.vmap(
lambda parent_link_idx: js.link.jacobian(
model=model,
data=data,
link_index=parent_link_idx,
output_vel_repr=VelRepr.Inertial,
)
)(jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int))

# Adjust the output representation.
match output_vel_repr:

case VelRepr.Inertial:
O_J_WC = W_J_WC

case VelRepr.Body:

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
C_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_C, inverse=True
)
C_J_WCi = C_X_W @ W_J_WC
return C_J_WCi
flferretti marked this conversation as resolved.
Show resolved Hide resolved

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)

case VelRepr.Mixed:

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:

W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))

CW_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_CW, inverse=True
)

CW_J_WC = CW_X_W @ W_J_WC
return CW_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)

case _:
raise ValueError(output_vel_repr)

return O_J_WC
37 changes: 37 additions & 0 deletions tests/test_contact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import jax
import pytest

import jaxsim.api as js
from jaxsim import VelRepr


def test_collidable_point_jacobians(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)

# =====
# Tests
# =====

# Compute the velocity of the collidable points with a RBDA.
# This function always returns the linear part of the mixed velocity of the
# implicit frame C corresponding to the collidable point.
W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data)

# Compute the generalized velocity and the free-floating Jacobian of the frame C.
ν = data.generalized_velocity()
CL_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? I was expecting it to be ${} ^{C[W]} J _{W, C}$ 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep you're right, that's the mixed velocity and I should have named the variable CW_J_WC. Thanks!


# Compute the velocity of the collidable points using the Jacobians.
v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CL_J_WC, ν)

assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3])
Loading