Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 126 additions & 144 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,9 +923,9 @@ class HwLinkMetadata(JaxsimDataclass):
Class storing the hardware parameters of a link.

Attributes:
shape: The shape of the link.
link_shape: The shape of the link.
0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
dims: The dimensions of the link.
geometry: The dimensions of the link.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0].
density: The density of the link.
L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
Expand All @@ -934,8 +934,8 @@ class HwLinkMetadata(JaxsimDataclass):
L_H_pre: The homogeneous transforms for child joints.
"""

shape: jtp.Vector
dims: jtp.Vector
link_shape: jtp.Vector
geometry: jtp.Vector
density: jtp.Float
L_H_G: jtp.Matrix
L_H_vis: jtp.Matrix
Expand Down Expand Up @@ -963,68 +963,64 @@ def compute_mass_and_inertia(
- inertia: The computed inertia tensor of the hardware link.
"""

mass, inertia = jax.lax.switch(
hw_link_metadata.shape,
[
HwLinkMetadata._box,
HwLinkMetadata._cylinder,
HwLinkMetadata._sphere,
],
hw_link_metadata.dims,
hw_link_metadata.density,
)
return mass, inertia
def box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
lx, ly, lz = dims

@staticmethod
def _box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
lx, ly, lz = dims
mass = density * lx * ly * lz

mass = density * lx * ly * lz
inertia = jnp.array(
[
[mass * (ly**2 + lz**2) / 12, 0, 0],
[0, mass * (lx**2 + lz**2) / 12, 0],
[0, 0, mass * (lx**2 + ly**2) / 12],
]
)
return mass, inertia

inertia = jnp.array(
[
[mass * (ly**2 + lz**2) / 12, 0, 0],
[0, mass * (lx**2 + lz**2) / 12, 0],
[0, 0, mass * (lx**2 + ly**2) / 12],
]
)
return mass, inertia
def cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
r, l, _ = dims

@staticmethod
def _cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
r, l, _ = dims
mass = density * (jnp.pi * r**2 * l)

mass = density * (jnp.pi * r**2 * l)
inertia = jnp.array(
[
[mass * (3 * r**2 + l**2) / 12, 0, 0],
[0, mass * (3 * r**2 + l**2) / 12, 0],
[0, 0, mass * (r**2) / 2],
]
)

inertia = jnp.array(
[
[mass * (3 * r**2 + l**2) / 12, 0, 0],
[0, mass * (3 * r**2 + l**2) / 12, 0],
[0, 0, mass * (r**2) / 2],
]
)
return mass, inertia

return mass, inertia
def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
r = dims[0]

@staticmethod
def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
r = dims[0]
mass = density * (4 / 3 * jnp.pi * r**3)

inertia = jnp.eye(3) * (2 / 5 * mass * r**2)

mass = density * (4 / 3 * jnp.pi * r**3)
return mass, inertia

inertia = jnp.eye(3) * (2 / 5 * mass * r**2)
def compute_mass_inertia(shape_idx, dims, density):
return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density)

mass, inertia = jax.vmap(compute_mass_inertia)(
hw_link_metadata.link_shape,
hw_link_metadata.geometry,
hw_link_metadata.density,
)

return mass, inertia

@staticmethod
def _convert_scaling_to_3d_vector(
shape: jtp.Int, scaling_factors: jtp.Vector
link_shapes: jtp.Int, scaling_factors: jtp.Vector
) -> jtp.Vector:
"""
Convert scaling factors for specific shape dimensions into a 3D scaling vector.

Args:
shape: The shape of the link (e.g., box, sphere, cylinder).
link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder).
scaling_factors: The scaling factors for the shape dimensions.

Returns:
Expand All @@ -1036,38 +1032,27 @@ def _convert_scaling_to_3d_vector(
- Cylinder: [r, r, l]
- Sphere: [r, r, r]
"""
return jax.lax.switch(
shape,
branches=[
# Box
lambda: jnp.array(
[
scaling_factors[0],
scaling_factors[1],
scaling_factors[2],
]
),
# Cylinder
lambda: jnp.array(
[
scaling_factors[0],
scaling_factors[0],
scaling_factors[1],
]
),
# Sphere
lambda: jnp.array(
[
scaling_factors[0],
scaling_factors[0],
scaling_factors[0],
]
),
],

# Index mapping for each shape type (link_shapes x 3 dims)
# Box: [lx, ly, lz] -> [0, 1, 2]
# Cylinder: [r, r, l] -> [0, 0, 1]
# Sphere: [r, r, r] -> [0, 0, 0]
shape_indices = jnp.array(
[
[0, 1, 2], # Box
[0, 0, 1], # Cylinder
[0, 0, 0], # Sphere
]
)

# For each link, get the index vector for its shape
per_link_indices = shape_indices[link_shapes]

# Gather dims per link according to per_link_indices
return scaling_factors.dims[per_link_indices.squeeze()]

@staticmethod
def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix:
def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix:
"""
Compute the inertia tensor of the link based on its shape and mass.
"""
Expand All @@ -1077,96 +1062,93 @@ def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix:

@staticmethod
def apply_scaling(
hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors
has_joints: bool,
hw_metadata: HwLinkMetadata,
scaling_factors: ScalingFactors,
) -> HwLinkMetadata:
"""
Apply scaling to the hardware parameters and return a new HwLinkMetadata object.

Args:
has_joints: A boolean indicating if the model has joints.
hw_metadata: the original HwLinkMetadata object.
scaling_factors: the scaling factors to apply.

Returns:
A new HwLinkMetadata object with updated parameters.
"""

# ==================================
# Handle unsupported links
# ==================================
def unsupported_case(hw_metadata, scaling_factors):
# Return the metadata unchanged for unsupported links
return hw_metadata

def supported_case(hw_metadata, scaling_factors):
# ==================================
# Update the kinematics of the link
# ==================================

# Get the nominal transforms
L_H_G = hw_metadata.L_H_G
L_H_vis = hw_metadata.L_H_vis
L_H_pre_array = hw_metadata.L_H_pre
L_H_pre_mask = hw_metadata.L_H_pre_mask

# Compute the 3D scaling vector
scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(
hw_metadata.shape, scaling_factors.dims
)
scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(
hw_metadata.link_shape, scaling_factors
)

# Express the transforms in the G frame
G_H_L = jaxsim.math.Transform.inverse(L_H_G)
G_H_vis = G_H_L @ L_H_vis
G_H_pre_array = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array)

# Apply the scaling to the position vectors
G_H̅_L = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3])
G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])
# Apply scaling to the position vectors in G_H_pre_array based on the mask
G_H̅_pre_array = jax.vmap(
lambda G_H_pre, mask: jnp.where(
# Expand mask for broadcasting
mask[..., None, None],
# Apply scaling
G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]),
# Keep unchanged if mask is False
G_H_pre,
)
)(G_H_pre_array, L_H_pre_mask)
# =================================
# Update the kinematics of the link
# =================================

# Get back to the link frame
L_H̅_G = jaxsim.math.Transform.inverse(G_H̅_L)
L_H̅_vis = L_H̅_G @ G_H̅_vis
L_H̅_pre_array = jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array)
# Get the nominal transforms
L_H_G = hw_metadata.L_H_G
L_H_vis = hw_metadata.L_H_vis
L_H_pre_array = hw_metadata.L_H_pre
L_H_pre_mask = hw_metadata.L_H_pre_mask

# ============================
# Update the shape parameters
# ============================
# Express the transforms in the G frame
G_H_L = jaxsim.math.Transform.inverse(L_H_G)
G_H_vis = G_H_L @ L_H_vis

updated_dims = hw_metadata.dims * scaling_factors.dims
G_H_pre_array = (
jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array)
if has_joints
else L_H_pre_array
)

# ==============================
# Scale the density of the link
# ==============================
# Apply the scaling to the position vectors
G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])

updated_density = hw_metadata.density * scaling_factors.density
# Apply scaling to the position vectors in G_H_pre_array based on the mask
G_H̅_pre_array = (
G_H_pre_array.at[:, :3, 3].set(
jnp.where(
L_H_pre_mask[:, None],
scale_vector[None, :] * G_H_pre_array[:, :3, 3],
G_H_pre_array[:, :3, 3],
)
)
if has_joints
else G_H_pre_array
)

# ============================
# Return updated HwLinkMetadata
# ============================
# Get back to the link frame
L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3])
L_H̅_vis = L_H̅_G @ G_H̅_vis
L_H̅_pre_array = (
jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array)
if has_joints
else G_H̅_pre_array
)

return hw_metadata.replace(
dims=updated_dims,
density=updated_density,
L_H_G=L_H̅_G,
L_H_vis=L_H̅_vis,
L_H_pre=L_H̅_pre_array,
)
# ===========================
# Update the shape parameters
# ===========================

updated_geoms = hw_metadata.geometry * scaling_factors.dims

# =============================
# Scale the density of the link
# =============================

updated_density = hw_metadata.density * scaling_factors.density

# =============================
# Return updated HwLinkMetadata
# =============================

# Use jax.lax.cond to handle unsupported links
return jax.lax.cond(
hw_metadata.shape == LinkParametrizableShape.Unsupported,
lambda: unsupported_case(hw_metadata, scaling_factors),
lambda: supported_case(hw_metadata, scaling_factors),
return hw_metadata.replace(
geometry=updated_geoms,
density=updated_density,
L_H_G=L_H̅_G,
L_H_vis=L_H̅_vis,
L_H_pre=L_H̅_pre_array,
)


Expand Down
Loading
Loading