From a36797a86a2c95237f8c9a9e4c7d4dd18793a069 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 19 May 2025 14:16:17 +0100 Subject: [PATCH 01/19] Simplify and speedup transform update --- src/jaxsim/api/model.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index ea513fdbe..30d1656b3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2389,22 +2389,18 @@ def update_λ_H_pre(joint_index): L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index] L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index] - # Use the mask to select the first valid transform or fall back to the original - valid_transforms = jnp.where( - L_H_pre_mask_for_joint[:, None, None], # Expand mask for broadcasting - L_H_pre_for_joint, # Use the transform if the mask is True - jnp.zeros_like(L_H_pre_for_joint), # Otherwise, use a zero matrix - ) + # Select the first valid transform (if any) using the mask + first_valid_index = jnp.argmax(L_H_pre_mask_for_joint) + selected_transform = L_H_pre_for_joint[first_valid_index] - # Sum the valid transforms (only one will be non-zero due to the mask) - selected_transform = jnp.sum(valid_transforms, axis=0) + # Check if any valid transform exists + has_valid_transform = L_H_pre_mask_for_joint.any() - # If no valid transform exists, fall back to the original λ_H_pre - return jax.lax.cond( - jnp.any(L_H_pre_mask_for_joint), - lambda: selected_transform, - lambda: kin_dyn_params.joint_model.λ_H_pre[joint_index + 1], - ) + # Fallback to the original λ_H_pre if no valid transform exists + fallback_transform = kin_dyn_params.joint_model.λ_H_pre[joint_index + 1] + + # Return the selected transform or fallback + return jnp.where(has_valid_transform, selected_transform, fallback_transform) # Apply the update function to all joint indices updated_λ_H_pre = jax.vmap(update_λ_H_pre)( From 0d6133b3ea8db198512d7b554b6951b6e68ce525 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 19 May 2025 17:39:55 +0100 Subject: [PATCH 02/19] Remove extra argument in `compute_inertia_link` and lint --- src/jaxsim/api/kin_dyn_parameters.py | 28 ++++++++++++++++------------ src/jaxsim/api/model.py | 2 ++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 52e723135..cdfb48ed6 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1067,7 +1067,7 @@ def _convert_scaling_to_3d_vector( ) @staticmethod - def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix: + def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: """ Compute the inertia tensor of the link based on its shape and mass. """ @@ -1090,17 +1090,20 @@ def apply_scaling( A new HwLinkMetadata object with updated parameters. """ - # ================================== - # Handle unsupported links - # ================================== def unsupported_case(hw_metadata, scaling_factors): + + # ======================== + # Handle unsupported links + # ======================== + # Return the metadata unchanged for unsupported links return hw_metadata def supported_case(hw_metadata, scaling_factors): - # ================================== + + # ================================= # Update the kinematics of the link - # ================================== + # ================================= # Get the nominal transforms L_H_G = hw_metadata.L_H_G @@ -1121,6 +1124,7 @@ def supported_case(hw_metadata, scaling_factors): # Apply the scaling to the position vectors G_H̅_L = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3]) G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) + # Apply scaling to the position vectors in G_H_pre_array based on the mask G_H̅_pre_array = jax.vmap( lambda G_H_pre, mask: jnp.where( @@ -1138,21 +1142,21 @@ def supported_case(hw_metadata, scaling_factors): L_H̅_vis = L_H̅_G @ G_H̅_vis L_H̅_pre_array = jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) - # ============================ + # =========================== # Update the shape parameters - # ============================ + # =========================== updated_dims = hw_metadata.dims * scaling_factors.dims - # ============================== + # ============================= # Scale the density of the link - # ============================== + # ============================= updated_density = hw_metadata.density * scaling_factors.density - # ============================ + # ============================= # Return updated HwLinkMetadata - # ============================ + # ============================= return hw_metadata.replace( dims=updated_dims, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 30d1656b3..d5df2b861 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2406,11 +2406,13 @@ def update_λ_H_pre(joint_index): updated_λ_H_pre = jax.vmap(update_λ_H_pre)( jnp.arange(kin_dyn_params.number_of_joints()) ) + # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal # to identity to represent the world-to-base tree transform. See JointModel class updated_λ_H_pre_with_base = jnp.concatenate( (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 ) + # Replace the joint model with the updated transforms updated_joint_model = kin_dyn_params.joint_model.replace( λ_H_pre=updated_λ_H_pre_with_base From 209bcb6d3c373473523dc487e4a5fa4ae51284be Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 13:07:07 +0100 Subject: [PATCH 03/19] Add parametrization support for joint-free models --- src/jaxsim/api/kin_dyn_parameters.py | 37 +++++++++++++++++++--------- src/jaxsim/api/model.py | 34 ++++++++++++++----------- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index cdfb48ed6..21d0481ea 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1111,6 +1111,9 @@ def supported_case(hw_metadata, scaling_factors): L_H_pre_array = hw_metadata.L_H_pre L_H_pre_mask = hw_metadata.L_H_pre_mask + # Check if the link has joints + has_joints = L_H_pre_array.shape != (0,) + # Compute the 3D scaling vector scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( hw_metadata.shape, scaling_factors.dims @@ -1119,28 +1122,38 @@ def supported_case(hw_metadata, scaling_factors): # Express the transforms in the G frame G_H_L = jaxsim.math.Transform.inverse(L_H_G) G_H_vis = G_H_L @ L_H_vis - G_H_pre_array = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) + + G_H_pre_array = ( + jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) + if has_joints + else L_H_pre_array + ) # Apply the scaling to the position vectors G_H̅_L = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3]) G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) # Apply scaling to the position vectors in G_H_pre_array based on the mask - G_H̅_pre_array = jax.vmap( - lambda G_H_pre, mask: jnp.where( - # Expand mask for broadcasting - mask[..., None, None], - # Apply scaling - G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]), - # Keep unchanged if mask is False - G_H_pre, - ) - )(G_H_pre_array, L_H_pre_mask) + G_H̅_pre_array = ( + jax.vmap( + lambda G_H_pre, mask: jnp.where( + mask[..., None, None], + G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]), + G_H_pre, + ) + )(G_H_pre_array, L_H_pre_mask) + if has_joints + else G_H_pre_array + ) # Get back to the link frame L_H̅_G = jaxsim.math.Transform.inverse(G_H̅_L) L_H̅_vis = L_H̅_G @ G_H̅_vis - L_H̅_pre_array = jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) + L_H̅_pre_array = ( + jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) + if has_joints + else G_H̅_pre_array + ) # =========================== # Update the shape parameters diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d5df2b861..249985725 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -472,7 +472,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: L_H_pre_masks.append( [ int(joint_index in child_joints_indices) - for joint_index in range(0, self.number_of_joints()) + for joint_index in range(self.number_of_joints()) ] ) L_H_pre.append( @@ -482,7 +482,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: if joint_index in child_joints_indices else jnp.eye(4) ) - for joint_index in range(0, self.number_of_joints()) + for joint_index in range(self.number_of_joints()) ] ) @@ -2402,21 +2402,25 @@ def update_λ_H_pre(joint_index): # Return the selected transform or fallback return jnp.where(has_valid_transform, selected_transform, fallback_transform) - # Apply the update function to all joint indices - updated_λ_H_pre = jax.vmap(update_λ_H_pre)( - jnp.arange(kin_dyn_params.number_of_joints()) - ) + if model.number_of_joints() > 0: + # Apply the update function to all joint indices + updated_λ_H_pre = jax.vmap(update_λ_H_pre)( + jnp.arange(kin_dyn_params.number_of_joints()) + ) - # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal - # to identity to represent the world-to-base tree transform. See JointModel class - updated_λ_H_pre_with_base = jnp.concatenate( - (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 - ) + # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal + # to identity to represent the world-to-base tree transform. See JointModel class + updated_λ_H_pre_with_base = jnp.concatenate( + (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 + ) - # Replace the joint model with the updated transforms - updated_joint_model = kin_dyn_params.joint_model.replace( - λ_H_pre=updated_λ_H_pre_with_base - ) + # Replace the joint model with the updated transforms + updated_joint_model = kin_dyn_params.joint_model.replace( + λ_H_pre=updated_λ_H_pre_with_base + ) + else: + # If there are no joints, we can just use the identity transform + updated_joint_model = kin_dyn_params.joint_model # Replace the kin_dyn_parameters with updated values updated_kin_dyn_params = kin_dyn_params.replace( From f773c27720e91a794044c46bdf413cc95a0ffa7a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 13:09:44 +0100 Subject: [PATCH 04/19] Avoid inverting twice `L_H_G` transform --- src/jaxsim/api/kin_dyn_parameters.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 21d0481ea..eefe11cf5 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1130,7 +1130,6 @@ def supported_case(hw_metadata, scaling_factors): ) # Apply the scaling to the position vectors - G_H̅_L = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3]) G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) # Apply scaling to the position vectors in G_H_pre_array based on the mask @@ -1147,7 +1146,7 @@ def supported_case(hw_metadata, scaling_factors): ) # Get back to the link frame - L_H̅_G = jaxsim.math.Transform.inverse(G_H̅_L) + L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) L_H̅_vis = L_H̅_G @ G_H̅_vis L_H̅_pre_array = ( jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) From 5fc1b0b4d48902cfa78591dc4f256e7d8f3599a2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 13:23:05 +0100 Subject: [PATCH 05/19] Avoid creating two full copies of `jnp.where` branch --- src/jaxsim/api/kin_dyn_parameters.py | 20 ++++++++++---------- src/jaxsim/api/model.py | 10 +++++++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index eefe11cf5..b1b0f6ca1 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1077,12 +1077,15 @@ def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: @staticmethod def apply_scaling( - hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors + has_joints: bool, + hw_metadata: HwLinkMetadata, + scaling_factors: ScalingFactors, ) -> HwLinkMetadata: """ Apply scaling to the hardware parameters and return a new HwLinkMetadata object. Args: + has_joints: A boolean indicating if the model has joints. hw_metadata: the original HwLinkMetadata object. scaling_factors: the scaling factors to apply. @@ -1111,9 +1114,6 @@ def supported_case(hw_metadata, scaling_factors): L_H_pre_array = hw_metadata.L_H_pre L_H_pre_mask = hw_metadata.L_H_pre_mask - # Check if the link has joints - has_joints = L_H_pre_array.shape != (0,) - # Compute the 3D scaling vector scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( hw_metadata.shape, scaling_factors.dims @@ -1134,13 +1134,13 @@ def supported_case(hw_metadata, scaling_factors): # Apply scaling to the position vectors in G_H_pre_array based on the mask G_H̅_pre_array = ( - jax.vmap( - lambda G_H_pre, mask: jnp.where( - mask[..., None, None], - G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]), - G_H_pre, + G_H_pre_array.at[:, :3, 3].set( + jnp.where( + L_H_pre_mask[:, None], + scale_vector[None, :] * G_H_pre_array[:, :3, 3], + G_H_pre_array[:, :3, 3], ) - )(G_H_pre_array, L_H_pre_mask) + ) if has_joints else G_H_pre_array ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 249985725..a3443f966 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2356,9 +2356,13 @@ def update_hw_parameters( link_parameters: LinkParameters = kin_dyn_params.link_parameters hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata + has_joints = model.number_of_joints() > 0 + # Apply scaling to hw_link_metadata using vmap - updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling)( - hw_link_metadata, scaling_factors + updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling, in_axes=(None,))( + has_joints, + hw_metadata=hw_link_metadata, + scaling_factors=scaling_factors, ) # Compute mass and inertia once and unpack the results @@ -2402,7 +2406,7 @@ def update_λ_H_pre(joint_index): # Return the selected transform or fallback return jnp.where(has_valid_transform, selected_transform, fallback_transform) - if model.number_of_joints() > 0: + if has_joints: # Apply the update function to all joint indices updated_λ_H_pre = jax.vmap(update_λ_H_pre)( jnp.arange(kin_dyn_params.number_of_joints()) From 56c364617f2877350003917f6caf9fff7f8e247a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:30:08 +0100 Subject: [PATCH 06/19] Encapsulate `HwLinkMetadata.shape` as a static attribute --- src/jaxsim/api/kin_dyn_parameters.py | 9 ++++++++- src/jaxsim/api/model.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index b1b0f6ca1..ce5469d88 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -934,7 +934,7 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre: The homogeneous transforms for child joints. """ - shape: jtp.Vector + _shape: Static[tuple[int]] dims: jtp.Vector density: jtp.Float L_H_G: jtp.Matrix @@ -942,6 +942,13 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre_mask: jtp.Vector L_H_pre: jtp.Matrix + @property + def shape(self) -> int: + """ + Return the shape of the link. + """ + return np.array(self._shape) + @staticmethod def compute_mass_and_inertia( hw_link_metadata: HwLinkMetadata, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a3443f966..247d6589d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -488,7 +488,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: # Stack collected data into JAX arrays return HwLinkMetadata( - shape=jnp.array(shapes, dtype=int), + _shape=shapes, dims=jnp.array(dims, dtype=float), density=jnp.array(densities, dtype=float), L_H_G=jnp.array(L_H_Gs, dtype=float), From 2aca2ad9273cdc4709697ad6c1e5cc56340758ab Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:32:52 +0100 Subject: [PATCH 07/19] Reduce memory footprint of `_convert_scaling_to_3d_vector` --- src/jaxsim/api/kin_dyn_parameters.py | 46 ++++++++++------------------ src/jaxsim/api/model.py | 4 +++ 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index ce5469d88..3f02fadf2 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1025,13 +1025,13 @@ def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: @staticmethod def _convert_scaling_to_3d_vector( - shape: jtp.Int, scaling_factors: jtp.Vector + shape_types: jtp.Int, scaling_factors: jtp.Vector ) -> jtp.Vector: """ Convert scaling factors for specific shape dimensions into a 3D scaling vector. Args: - shape: The shape of the link (e.g., box, sphere, cylinder). + shape_types: The shape_types of the link (e.g., box, sphere, cylinder). scaling_factors: The scaling factors for the shape dimensions. Returns: @@ -1043,36 +1043,22 @@ def _convert_scaling_to_3d_vector( - Cylinder: [r, r, l] - Sphere: [r, r, r] """ - return jax.lax.switch( - shape, - branches=[ - # Box - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[1], - scaling_factors[2], - ] - ), - # Cylinder - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[0], - scaling_factors[1], - ] - ), - # Sphere - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[0], - scaling_factors[0], - ] - ), - ], + + # Index mapping for each shape type (shape_type x 3 dims) + shape_indices = jnp.array( + [ + [0, 1, 2], # Box + [0, 0, 1], # Cylinder + [0, 0, 0], # Sphere + ] ) + # For each link, get the index vector for its shape + per_link_indices = shape_indices[shape_types] + + # Gather dims per link according to per_link_indices + return jnp.take_along_axis(scaling_factors.dims, per_link_indices, axis=1) + @staticmethod def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: """ diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 247d6589d..7d3036411 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2358,6 +2358,10 @@ def update_hw_parameters( has_joints = model.number_of_joints() > 0 + scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( + hw_link_metadata.shape, scaling_factors + ) + # Apply scaling to hw_link_metadata using vmap updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling, in_axes=(None,))( has_joints, From 0543a163a7679cdfe4dd94be1ac7dbb9037bf5e2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:35:11 +0100 Subject: [PATCH 08/19] Refactor `compute_mass_and_inertia` to use static shapes --- src/jaxsim/api/kin_dyn_parameters.py | 78 ++++++++++++++-------------- 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 3f02fadf2..f28d7b59c 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -951,6 +951,7 @@ def shape(self) -> int: @staticmethod def compute_mass_and_inertia( + shape_types: jtp.Array, hw_link_metadata: HwLinkMetadata, ) -> tuple[jtp.Float, jtp.Matrix]: """ @@ -961,6 +962,7 @@ def compute_mass_and_inertia( by using shape-specific methods. Args: + shape_types: The shape types of the link (e.g., box, sphere, cylinder). hw_link_metadata: Metadata describing the hardware link, including its shape, dimensions, and density. @@ -970,56 +972,52 @@ def compute_mass_and_inertia( - inertia: The computed inertia tensor of the hardware link. """ - mass, inertia = jax.lax.switch( - hw_link_metadata.shape, - [ - HwLinkMetadata._box, - HwLinkMetadata._cylinder, - HwLinkMetadata._sphere, - ], - hw_link_metadata.dims, - hw_link_metadata.density, - ) - return mass, inertia + def box(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + lx, ly, lz = dims - @staticmethod - def _box(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - lx, ly, lz = dims + mass = density * lx * ly * lz - mass = density * lx * ly * lz + inertia = jnp.array( + [ + [mass * (ly**2 + lz**2) / 12, 0, 0], + [0, mass * (lx**2 + lz**2) / 12, 0], + [0, 0, mass * (lx**2 + ly**2) / 12], + ] + ) + return mass, inertia - inertia = jnp.array( - [ - [mass * (ly**2 + lz**2) / 12, 0, 0], - [0, mass * (lx**2 + lz**2) / 12, 0], - [0, 0, mass * (lx**2 + ly**2) / 12], - ] - ) - return mass, inertia + def cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + r, l, _ = dims - @staticmethod - def _cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - r, l, _ = dims + mass = density * (jnp.pi * r**2 * l) - mass = density * (jnp.pi * r**2 * l) + inertia = jnp.array( + [ + [mass * (3 * r**2 + l**2) / 12, 0, 0], + [0, mass * (3 * r**2 + l**2) / 12, 0], + [0, 0, mass * (r**2) / 2], + ] + ) - inertia = jnp.array( - [ - [mass * (3 * r**2 + l**2) / 12, 0, 0], - [0, mass * (3 * r**2 + l**2) / 12, 0], - [0, 0, mass * (r**2) / 2], - ] - ) + return mass, inertia - return mass, inertia + def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + r = dims[0] - @staticmethod - def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - r = dims[0] + mass = density * (4 / 3 * jnp.pi * r**3) + + inertia = jnp.eye(3) * (2 / 5 * mass * r**2) + + return mass, inertia - mass = density * (4 / 3 * jnp.pi * r**3) + def compute_mass_inertia(shape_idx, dims, density): + return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density) - inertia = jnp.eye(3) * (2 / 5 * mass * r**2) + mass, inertia = jax.vmap(compute_mass_inertia)( + jnp.array(shape_types), + hw_link_metadata.dims, + hw_link_metadata.density, + ) return mass, inertia From c5caeb2a34e80ea459080c0442b48c509aec1cb2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:36:50 +0100 Subject: [PATCH 09/19] Refactor handling of unsupported links --- src/jaxsim/api/kin_dyn_parameters.py | 131 +++++++++++---------------- src/jaxsim/api/model.py | 36 +++++++- 2 files changed, 86 insertions(+), 81 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index f28d7b59c..7fb7d00a6 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1069,6 +1069,7 @@ def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: @staticmethod def apply_scaling( has_joints: bool, + scale_vector: jtp.Vector, hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors, ) -> HwLinkMetadata: @@ -1077,6 +1078,7 @@ def apply_scaling( Args: has_joints: A boolean indicating if the model has joints. + scale_vector: The scaling vector to apply. hw_metadata: the original HwLinkMetadata object. scaling_factors: the scaling factors to apply. @@ -1084,96 +1086,73 @@ def apply_scaling( A new HwLinkMetadata object with updated parameters. """ - def unsupported_case(hw_metadata, scaling_factors): - - # ======================== - # Handle unsupported links - # ======================== - - # Return the metadata unchanged for unsupported links - return hw_metadata - - def supported_case(hw_metadata, scaling_factors): - - # ================================= - # Update the kinematics of the link - # ================================= - - # Get the nominal transforms - L_H_G = hw_metadata.L_H_G - L_H_vis = hw_metadata.L_H_vis - L_H_pre_array = hw_metadata.L_H_pre - L_H_pre_mask = hw_metadata.L_H_pre_mask + # ================================= + # Update the kinematics of the link + # ================================= - # Compute the 3D scaling vector - scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - hw_metadata.shape, scaling_factors.dims - ) + # Get the nominal transforms + L_H_G = hw_metadata.L_H_G + L_H_vis = hw_metadata.L_H_vis + L_H_pre_array = hw_metadata.L_H_pre + L_H_pre_mask = hw_metadata.L_H_pre_mask - # Express the transforms in the G frame - G_H_L = jaxsim.math.Transform.inverse(L_H_G) - G_H_vis = G_H_L @ L_H_vis + # Express the transforms in the G frame + G_H_L = jaxsim.math.Transform.inverse(L_H_G) + G_H_vis = G_H_L @ L_H_vis - G_H_pre_array = ( - jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) - if has_joints - else L_H_pre_array - ) + G_H_pre_array = ( + jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) + if has_joints + else L_H_pre_array + ) - # Apply the scaling to the position vectors - G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) + # Apply the scaling to the position vectors + G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) - # Apply scaling to the position vectors in G_H_pre_array based on the mask - G_H̅_pre_array = ( - G_H_pre_array.at[:, :3, 3].set( - jnp.where( - L_H_pre_mask[:, None], - scale_vector[None, :] * G_H_pre_array[:, :3, 3], - G_H_pre_array[:, :3, 3], - ) + # Apply scaling to the position vectors in G_H_pre_array based on the mask + G_H̅_pre_array = ( + G_H_pre_array.at[:, :3, 3].set( + jnp.where( + L_H_pre_mask[:, None], + scale_vector[None, :] * G_H_pre_array[:, :3, 3], + G_H_pre_array[:, :3, 3], ) - if has_joints - else G_H_pre_array - ) - - # Get back to the link frame - L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) - L_H̅_vis = L_H̅_G @ G_H̅_vis - L_H̅_pre_array = ( - jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) - if has_joints - else G_H̅_pre_array ) + if has_joints + else G_H_pre_array + ) - # =========================== - # Update the shape parameters - # =========================== + # Get back to the link frame + L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) + L_H̅_vis = L_H̅_G @ G_H̅_vis + L_H̅_pre_array = ( + jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) + if has_joints + else G_H̅_pre_array + ) - updated_dims = hw_metadata.dims * scaling_factors.dims + # =========================== + # Update the shape parameters + # =========================== - # ============================= - # Scale the density of the link - # ============================= + updated_dims = hw_metadata.dims * scaling_factors.dims - updated_density = hw_metadata.density * scaling_factors.density + # ============================= + # Scale the density of the link + # ============================= - # ============================= - # Return updated HwLinkMetadata - # ============================= + updated_density = hw_metadata.density * scaling_factors.density - return hw_metadata.replace( - dims=updated_dims, - density=updated_density, - L_H_G=L_H̅_G, - L_H_vis=L_H̅_vis, - L_H_pre=L_H̅_pre_array, - ) + # ============================= + # Return updated HwLinkMetadata + # ============================= - # Use jax.lax.cond to handle unsupported links - return jax.lax.cond( - hw_metadata.shape == LinkParametrizableShape.Unsupported, - lambda: unsupported_case(hw_metadata, scaling_factors), - lambda: supported_case(hw_metadata, scaling_factors), + return hw_metadata.replace( + dims=updated_dims, + density=updated_density, + L_H_G=L_H̅_G, + L_H_vis=L_H̅_vis, + L_H_pre=L_H̅_pre_array, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 7d3036411..262b0b4ba 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2358,20 +2358,46 @@ def update_hw_parameters( has_joints = model.number_of_joints() > 0 + supported_mask = hw_link_metadata.shape != LinkParametrizableShape.Unsupported + + supported_metadata = jax.tree.map(lambda l: l[supported_mask], hw_link_metadata) + + supported_scaling_factors = jax.tree.map( + lambda l: l[supported_mask], scaling_factors + ) + scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - hw_link_metadata.shape, scaling_factors + supported_metadata.shape, supported_scaling_factors ) # Apply scaling to hw_link_metadata using vmap - updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling, in_axes=(None,))( + scaled_hw_link_metadata_supported = jax.vmap( + HwLinkMetadata.apply_scaling, in_axes=(None,) + )( has_joints, - hw_metadata=hw_link_metadata, + scale_vector=scale_vector, + hw_metadata=supported_metadata, scaling_factors=scaling_factors, ) + # Helper function to merge pytrees leaf-wise with boolean mask + def merge_pytree_by_mask(scaled_pytree, original_pytree, mask): + + def merge_leaf(scaled_leaf, original_leaf): + mask_shape = (mask.shape[0],) + (1,) * (scaled_leaf.ndim - 1) + mask_broadcasted = mask.reshape(mask_shape) + + return jnp.where(mask_broadcasted, scaled_leaf, original_leaf) + + return jax.tree.map(merge_leaf, scaled_pytree, original_pytree) + + updated_hw_link_metadata = merge_pytree_by_mask( + scaled_hw_link_metadata_supported, hw_link_metadata, supported_mask + ) + # Compute mass and inertia once and unpack the results - m_updated, I_com_updated = jax.vmap(HwLinkMetadata.compute_mass_and_inertia)( - updated_hw_link_metadata + m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( + hw_link_metadata.shape, updated_hw_link_metadata ) # Rotate the inertia tensor at CoM with the link orientation, and store From 7d4efbb484b60e21d000fa2d20f0dfbc3490c410 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:37:20 +0100 Subject: [PATCH 10/19] Update hardware parametrization tests --- tests/test_api_model_hw_parametrization.py | 52 ++++++++++------------ 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 5cdb1243c..098565894 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -95,38 +95,34 @@ def test_model_scaling_against_rod( ) # Compare hardware parameters of the scaled JaxSim model with the pre-scaled JaxSim model - for link_idx, link_name in enumerate(jaxsim_model_garpez.link_names()): - scaled_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - updated_model.kin_dyn_parameters.hw_link_metadata, - ) - pre_scaled_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata, - ) + scaled_metadata = updated_model.kin_dyn_parameters.hw_link_metadata - # Compare shape dimensions - assert jnp.allclose(scaled_metadata.dims, pre_scaled_metadata.dims, atol=1e-6) + pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata - # Compare mass - scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) - pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata - ) - assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) + # Compare shape dimensions + assert jnp.allclose(scaled_metadata.dims, pre_scaled_metadata.dims, atol=1e-6) - # Compare inertia tensors - _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) - _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata - ) - assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) + # Compare mass + scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( + scaled_metadata.shape, scaled_metadata + ) + pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( + pre_scaled_metadata.shape, pre_scaled_metadata + ) + assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) - # Compare transformations - assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) - assert jnp.allclose( - scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6 - ) + # Compare inertia tensors + _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( + scaled_metadata.shape, scaled_metadata + ) + _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( + pre_scaled_metadata.shape, pre_scaled_metadata + ) + assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) + + # Compare transformations + assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) + assert jnp.allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6) def test_update_hw_parameters_vmap( From 510fa2b5d63277f46897e154a948218d5b0919ab Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 May 2025 15:43:56 +0100 Subject: [PATCH 11/19] Clean up imports --- src/jaxsim/api/model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 262b0b4ba..94e6cae96 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -10,13 +10,12 @@ import jax import jax.numpy as jnp import jax_dataclasses +import numpy as np import rod -import rod.urdf from jax_dataclasses import Static from rod.urdf.exporter import UrdfExporter import jaxsim.api as js -import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp from jaxsim import logging @@ -508,8 +507,6 @@ def export_updated_model(self) -> str: This method is not meant to be used in JIT-compiled functions. """ - import numpy as np - if isinstance(jnp.zeros(0), jax.core.Tracer): raise RuntimeError("This method cannot be used in JIT-compiled functions") From 09d539c439d1e598e7eab6ff5c21cc0faa7509fc Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 15:08:11 +0200 Subject: [PATCH 12/19] Add more explanatory comment on shape index --- src/jaxsim/api/kin_dyn_parameters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 7fb7d00a6..32f6dda89 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1043,6 +1043,9 @@ def _convert_scaling_to_3d_vector( """ # Index mapping for each shape type (shape_type x 3 dims) + # Box: [lx, ly, lz] -> [0, 1, 2] + # Cylinder: [r, r, l] -> [0, 0, 1] + # Sphere: [r, r, r] -> [0, 0, 0] shape_indices = jnp.array( [ [0, 1, 2], # Box From 94ec899393594d5dcb85978e0477146b0fa3727d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 15:08:45 +0200 Subject: [PATCH 13/19] Move computation of `scale_vector` inside `apply_scaling` --- src/jaxsim/api/kin_dyn_parameters.py | 6 ++++-- src/jaxsim/api/model.py | 7 +------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 32f6dda89..d30b0d4dd 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1072,7 +1072,6 @@ def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: @staticmethod def apply_scaling( has_joints: bool, - scale_vector: jtp.Vector, hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors, ) -> HwLinkMetadata: @@ -1081,7 +1080,6 @@ def apply_scaling( Args: has_joints: A boolean indicating if the model has joints. - scale_vector: The scaling vector to apply. hw_metadata: the original HwLinkMetadata object. scaling_factors: the scaling factors to apply. @@ -1089,6 +1087,10 @@ def apply_scaling( A new HwLinkMetadata object with updated parameters. """ + scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( + hw_metadata.shape, scaling_factors + ) + # ================================= # Update the kinematics of the link # ================================= diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 94e6cae96..9e5aab32a 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2363,18 +2363,13 @@ def update_hw_parameters( lambda l: l[supported_mask], scaling_factors ) - scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - supported_metadata.shape, supported_scaling_factors - ) - # Apply scaling to hw_link_metadata using vmap scaled_hw_link_metadata_supported = jax.vmap( HwLinkMetadata.apply_scaling, in_axes=(None,) )( has_joints, - scale_vector=scale_vector, hw_metadata=supported_metadata, - scaling_factors=scaling_factors, + scaling_factors=supported_scaling_factors, ) # Helper function to merge pytrees leaf-wise with boolean mask From ee2df016d98356e6fa5916762ce65156e9992aaf Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 16:24:15 +0200 Subject: [PATCH 14/19] Rename `HwLinkMetadata` attributes --- src/jaxsim/api/kin_dyn_parameters.py | 16 +++--- src/jaxsim/api/model.py | 24 ++++----- tests/test_api_model_hw_parametrization.py | 57 ++++++++++++---------- 3 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index d30b0d4dd..ed66f5ea1 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -923,9 +923,9 @@ class HwLinkMetadata(JaxsimDataclass): Class storing the hardware parameters of a link. Attributes: - shape: The shape of the link. + link_shape: The shape of the link. 0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported. - dims: The dimensions of the link. + geometry: The dimensions of the link. box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0]. density: The density of the link. L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G. @@ -934,8 +934,8 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre: The homogeneous transforms for child joints. """ - _shape: Static[tuple[int]] - dims: jtp.Vector + _link_shape: Static[list[float]] + geometry: jtp.Vector density: jtp.Float L_H_G: jtp.Matrix L_H_vis: jtp.Matrix @@ -1015,7 +1015,7 @@ def compute_mass_inertia(shape_idx, dims, density): mass, inertia = jax.vmap(compute_mass_inertia)( jnp.array(shape_types), - hw_link_metadata.dims, + hw_link_metadata.geometry, hw_link_metadata.density, ) @@ -1088,7 +1088,7 @@ def apply_scaling( """ scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - hw_metadata.shape, scaling_factors + hw_metadata.link_shape, scaling_factors ) # ================================= @@ -1140,7 +1140,7 @@ def apply_scaling( # Update the shape parameters # =========================== - updated_dims = hw_metadata.dims * scaling_factors.dims + updated_geoms = hw_metadata.geometry * scaling_factors.dims # ============================= # Scale the density of the link @@ -1153,7 +1153,7 @@ def apply_scaling( # ============================= return hw_metadata.replace( - dims=updated_dims, + geometry=updated_geoms, density=updated_density, L_H_G=L_H̅_G, L_H_vis=L_H̅_vis, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 9e5aab32a..a67cffd55 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -361,8 +361,8 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: "Skipping for hardware parametrization." ) return HwLinkMetadata( - shape=jnp.array([]), - dims=jnp.array([]), + link_shape=[], + geometry=jnp.array([]), density=jnp.array([]), L_H_G=jnp.array([]), L_H_vis=jnp.array([]), @@ -397,7 +397,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: # Initialize lists to collect metadata for all links shapes = [] - dims = [] + geoms = [] densities = [] L_H_Gs = [] L_H_vises = [] @@ -447,17 +447,17 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: if isinstance(geometry, rod.Box): lx, ly, lz = geometry.size density = mass / (lx * ly * lz) - dims.append([lx, ly, lz]) + geoms.append([lx, ly, lz]) shapes.append(LinkParametrizableShape.Box) elif isinstance(geometry, rod.Sphere): r = geometry.radius density = mass / (4 / 3 * jnp.pi * r**3) - dims.append([r, 0, 0]) + geoms.append([r, 0, 0]) shapes.append(LinkParametrizableShape.Sphere) elif isinstance(geometry, rod.Cylinder): r, l = geometry.radius, geometry.length density = mass / (jnp.pi * r**2 * l) - dims.append([r, l, 0]) + geoms.append([r, l, 0]) shapes.append(LinkParametrizableShape.Cylinder) else: logging.debug( @@ -487,8 +487,8 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: # Stack collected data into JAX arrays return HwLinkMetadata( - _shape=shapes, - dims=jnp.array(dims, dtype=float), + _link_shape=shapes, + geometry=jnp.array(geoms, dtype=float), density=jnp.array(densities, dtype=float), L_H_G=jnp.array(L_H_Gs, dtype=float), L_H_vis=jnp.array(L_H_vises, dtype=float), @@ -558,8 +558,8 @@ def export_updated_model(self) -> str: ) # Update visual shape - shape = hw_metadata.shape[link_index] - dims = hw_metadata.dims[link_index] + shape = hw_metadata.link_shape[link_index] + dims = hw_metadata.geometry[link_index] if shape == LinkParametrizableShape.Box: links_dict[link_name].visual.geometry.box.size = dims.tolist() elif shape == LinkParametrizableShape.Sphere: @@ -2355,7 +2355,7 @@ def update_hw_parameters( has_joints = model.number_of_joints() > 0 - supported_mask = hw_link_metadata.shape != LinkParametrizableShape.Unsupported + supported_mask = hw_link_metadata.link_shape != LinkParametrizableShape.Unsupported supported_metadata = jax.tree.map(lambda l: l[supported_mask], hw_link_metadata) @@ -2389,7 +2389,7 @@ def merge_leaf(scaled_leaf, original_leaf): # Compute mass and inertia once and unpack the results m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( - hw_link_metadata.shape, updated_hw_link_metadata + hw_link_metadata.link_shape, updated_hw_link_metadata ) # Rotate the inertia tensor at CoM with the link orientation, and store diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 098565894..d418c5c20 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -30,30 +30,31 @@ def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel): density=jnp.ones(4), ) - # Update the model using the scaling factors - updated_model = js.model.update_hw_parameters(model, scaling_parameters) - - # Compare updated hardware parameters - for link_idx, link_name in enumerate(model.link_names()): - updated_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - updated_model.kin_dyn_parameters.hw_link_metadata, - ) - initial_metadata_link = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], initial_metadata - ) + with jax.disable_jit(False): + # Update the model using the scaling factors + updated_model = js.model.update_hw_parameters(model, scaling_parameters) + + # Compare updated hardware parameters + for link_idx, link_name in enumerate(model.link_names()): + updated_metadata = jax.tree_util.tree_map( + lambda x, link_idx=link_idx: x[link_idx], + updated_model.kin_dyn_parameters.hw_link_metadata, + ) + initial_metadata_link = jax.tree_util.tree_map( + lambda x, link_idx=link_idx: x[link_idx], initial_metadata + ) - # TODO: Compute the 3D scaling vector - # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - # initial_metadata_link.shape, scaling_parameters.dims[link_idx] - # ) + # TODO: Compute the 3D scaling vector + # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( + # initial_metadata_link.shape, scaling_parameters.dims[link_idx] + # ) - # Compare shape dimensions - assert jnp.allclose( - updated_metadata.dims, - initial_metadata_link.dims * scaling_parameters.dims[link_idx], - atol=1e-6, - ), f"Mismatch in dimensions for link {link_name}: expected {initial_metadata_link.dims * scaling_parameters.dims[link_idx]}, got {updated_metadata.dims}" + # Compare shape dimensions + assert jnp.allclose( + updated_metadata.geometry, + initial_metadata_link.geometry * scaling_parameters.dims[link_idx], + atol=1e-6, + ), f"Mismatch in dimensions for link {link_name}: expected {initial_metadata_link.geometry * scaling_parameters.dims[link_idx]}, got {updated_metadata.geometry}" @pytest.mark.parametrize( @@ -100,23 +101,25 @@ def test_model_scaling_against_rod( pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata # Compare shape dimensions - assert jnp.allclose(scaled_metadata.dims, pre_scaled_metadata.dims, atol=1e-6) + assert jnp.allclose( + scaled_metadata.geometry, pre_scaled_metadata.geometry, atol=1e-6 + ) # Compare mass scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - scaled_metadata.shape, scaled_metadata + scaled_metadata.link_shape, scaled_metadata ) pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata.shape, pre_scaled_metadata + pre_scaled_metadata.link_shape, pre_scaled_metadata ) assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) # Compare inertia tensors _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - scaled_metadata.shape, scaled_metadata + scaled_metadata.link_shape, scaled_metadata ) _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata.shape, pre_scaled_metadata + pre_scaled_metadata.link_shape, pre_scaled_metadata ) assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) From 1ab34561a4e3f8ec7fb6afbc780b72b93f42d512 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 17:02:33 +0200 Subject: [PATCH 15/19] Restore unsupported link handling with `jax.vmap` --- src/jaxsim/api/kin_dyn_parameters.py | 13 ++------ src/jaxsim/api/model.py | 44 +++++++++------------------- 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index ed66f5ea1..ed30fe633 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -934,7 +934,7 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre: The homogeneous transforms for child joints. """ - _link_shape: Static[list[float]] + link_shape: jtp.Vector geometry: jtp.Vector density: jtp.Float L_H_G: jtp.Matrix @@ -942,13 +942,6 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre_mask: jtp.Vector L_H_pre: jtp.Matrix - @property - def shape(self) -> int: - """ - Return the shape of the link. - """ - return np.array(self._shape) - @staticmethod def compute_mass_and_inertia( shape_types: jtp.Array, @@ -1014,7 +1007,7 @@ def compute_mass_inertia(shape_idx, dims, density): return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density) mass, inertia = jax.vmap(compute_mass_inertia)( - jnp.array(shape_types), + shape_types, hw_link_metadata.geometry, hw_link_metadata.density, ) @@ -1058,7 +1051,7 @@ def _convert_scaling_to_3d_vector( per_link_indices = shape_indices[shape_types] # Gather dims per link according to per_link_indices - return jnp.take_along_axis(scaling_factors.dims, per_link_indices, axis=1) + return scaling_factors.dims[per_link_indices.squeeze()] @staticmethod def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a67cffd55..0d04cd4a7 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -361,7 +361,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: "Skipping for hardware parametrization." ) return HwLinkMetadata( - link_shape=[], + link_shape=jnp.array([]), geometry=jnp.array([]), density=jnp.array([]), L_H_G=jnp.array([]), @@ -487,7 +487,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: # Stack collected data into JAX arrays return HwLinkMetadata( - _link_shape=shapes, + link_shape=jnp.array(shapes, dtype=int), geometry=jnp.array(geoms, dtype=float), density=jnp.array(densities, dtype=float), L_H_G=jnp.array(L_H_Gs, dtype=float), @@ -2355,37 +2355,21 @@ def update_hw_parameters( has_joints = model.number_of_joints() > 0 - supported_mask = hw_link_metadata.link_shape != LinkParametrizableShape.Unsupported - - supported_metadata = jax.tree.map(lambda l: l[supported_mask], hw_link_metadata) - - supported_scaling_factors = jax.tree.map( - lambda l: l[supported_mask], scaling_factors + supported_case = lambda hw_metadata, scaling_factors: HwLinkMetadata.apply_scaling( + hw_metadata=hw_metadata, scaling_factors=scaling_factors, has_joints=has_joints ) + unsupported_case = lambda hw_metadata, scaling_factors: hw_metadata # Apply scaling to hw_link_metadata using vmap - scaled_hw_link_metadata_supported = jax.vmap( - HwLinkMetadata.apply_scaling, in_axes=(None,) - )( - has_joints, - hw_metadata=supported_metadata, - scaling_factors=supported_scaling_factors, - ) - - # Helper function to merge pytrees leaf-wise with boolean mask - def merge_pytree_by_mask(scaled_pytree, original_pytree, mask): - - def merge_leaf(scaled_leaf, original_leaf): - mask_shape = (mask.shape[0],) + (1,) * (scaled_leaf.ndim - 1) - mask_broadcasted = mask.reshape(mask_shape) - - return jnp.where(mask_broadcasted, scaled_leaf, original_leaf) - - return jax.tree.map(merge_leaf, scaled_pytree, original_pytree) - - updated_hw_link_metadata = merge_pytree_by_mask( - scaled_hw_link_metadata_supported, hw_link_metadata, supported_mask - ) + updated_hw_link_metadata = jax.vmap( + lambda hw_metadata, multipliers: jax.lax.cond( + hw_metadata.link_shape == LinkParametrizableShape.Unsupported, + unsupported_case, + supported_case, + hw_metadata, + multipliers, + ) + )(hw_link_metadata, scaling_factors) # Compute mass and inertia once and unpack the results m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( From 44d2d141f154aee589350756a642800e9d904fa8 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 17:15:25 +0200 Subject: [PATCH 16/19] Remove extra arg in `compute_mass_and_inertia` --- src/jaxsim/api/kin_dyn_parameters.py | 4 +--- src/jaxsim/api/model.py | 2 +- tests/test_api_model_hw_parametrization.py | 16 ++++------------ 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index ed30fe633..72282b216 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -944,7 +944,6 @@ class HwLinkMetadata(JaxsimDataclass): @staticmethod def compute_mass_and_inertia( - shape_types: jtp.Array, hw_link_metadata: HwLinkMetadata, ) -> tuple[jtp.Float, jtp.Matrix]: """ @@ -955,7 +954,6 @@ def compute_mass_and_inertia( by using shape-specific methods. Args: - shape_types: The shape types of the link (e.g., box, sphere, cylinder). hw_link_metadata: Metadata describing the hardware link, including its shape, dimensions, and density. @@ -1007,7 +1005,7 @@ def compute_mass_inertia(shape_idx, dims, density): return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density) mass, inertia = jax.vmap(compute_mass_inertia)( - shape_types, + hw_link_metadata.link_shape, hw_link_metadata.geometry, hw_link_metadata.density, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 0d04cd4a7..d24fd2d58 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2373,7 +2373,7 @@ def update_hw_parameters( # Compute mass and inertia once and unpack the results m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( - hw_link_metadata.link_shape, updated_hw_link_metadata + updated_hw_link_metadata ) # Rotate the inertia tensor at CoM with the link orientation, and store diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index d418c5c20..6691328d2 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -106,21 +106,13 @@ def test_model_scaling_against_rod( ) # Compare mass - scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - scaled_metadata.link_shape, scaled_metadata - ) - pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata.link_shape, pre_scaled_metadata - ) + scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) + pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata) assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) # Compare inertia tensors - _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - scaled_metadata.link_shape, scaled_metadata - ) - _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata.link_shape, pre_scaled_metadata - ) + _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) + _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata) assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) # Compare transformations From b133399520364f0b5da08f9e408c9b05fc8707af Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 8 Sep 2025 17:17:47 +0200 Subject: [PATCH 17/19] Rename `shape_types` with `link_shapes` --- src/jaxsim/api/kin_dyn_parameters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 72282b216..9aa78a8ba 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1014,13 +1014,13 @@ def compute_mass_inertia(shape_idx, dims, density): @staticmethod def _convert_scaling_to_3d_vector( - shape_types: jtp.Int, scaling_factors: jtp.Vector + link_shapes: jtp.Int, scaling_factors: jtp.Vector ) -> jtp.Vector: """ Convert scaling factors for specific shape dimensions into a 3D scaling vector. Args: - shape_types: The shape_types of the link (e.g., box, sphere, cylinder). + link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder). scaling_factors: The scaling factors for the shape dimensions. Returns: @@ -1033,7 +1033,7 @@ def _convert_scaling_to_3d_vector( - Sphere: [r, r, r] """ - # Index mapping for each shape type (shape_type x 3 dims) + # Index mapping for each shape type (link_shapes x 3 dims) # Box: [lx, ly, lz] -> [0, 1, 2] # Cylinder: [r, r, l] -> [0, 0, 1] # Sphere: [r, r, r] -> [0, 0, 0] @@ -1046,7 +1046,7 @@ def _convert_scaling_to_3d_vector( ) # For each link, get the index vector for its shape - per_link_indices = shape_indices[shape_types] + per_link_indices = shape_indices[link_shapes] # Gather dims per link according to per_link_indices return scaling_factors.dims[per_link_indices.squeeze()] From ce4ccfdc4b3ebe07a086be5552473bba98d80229 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 9 Sep 2025 10:48:07 +0200 Subject: [PATCH 18/19] Remove leftover `jax.disable_jit` --- tests/test_api_model_hw_parametrization.py | 47 +++++++++++----------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 6691328d2..b7d4d666a 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -30,31 +30,30 @@ def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel): density=jnp.ones(4), ) - with jax.disable_jit(False): - # Update the model using the scaling factors - updated_model = js.model.update_hw_parameters(model, scaling_parameters) - - # Compare updated hardware parameters - for link_idx, link_name in enumerate(model.link_names()): - updated_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - updated_model.kin_dyn_parameters.hw_link_metadata, - ) - initial_metadata_link = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], initial_metadata - ) + # Update the model using the scaling factors + updated_model = js.model.update_hw_parameters(model, scaling_parameters) - # TODO: Compute the 3D scaling vector - # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - # initial_metadata_link.shape, scaling_parameters.dims[link_idx] - # ) - - # Compare shape dimensions - assert jnp.allclose( - updated_metadata.geometry, - initial_metadata_link.geometry * scaling_parameters.dims[link_idx], - atol=1e-6, - ), f"Mismatch in dimensions for link {link_name}: expected {initial_metadata_link.geometry * scaling_parameters.dims[link_idx]}, got {updated_metadata.geometry}" + # Compare updated hardware parameters + for link_idx, link_name in enumerate(model.link_names()): + updated_metadata = jax.tree_util.tree_map( + lambda x, link_idx=link_idx: x[link_idx], + updated_model.kin_dyn_parameters.hw_link_metadata, + ) + initial_metadata_link = jax.tree_util.tree_map( + lambda x, link_idx=link_idx: x[link_idx], initial_metadata + ) + + # TODO: Compute the 3D scaling vector + # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( + # initial_metadata_link.shape, scaling_parameters.dims[link_idx] + # ) + + # Compare shape dimensions + assert jnp.allclose( + updated_metadata.geometry, + initial_metadata_link.geometry * scaling_parameters.dims[link_idx], + atol=1e-6, + ), f"Mismatch in dimensions for link {link_name}: expected {initial_metadata_link.geometry * scaling_parameters.dims[link_idx]}, got {updated_metadata.geometry}" @pytest.mark.parametrize( From f8b58c1ce31ecacfd9cc53518ceedfbd45afb62a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 9 Sep 2025 13:29:21 +0200 Subject: [PATCH 19/19] Update variable names --- src/jaxsim/api/kin_dyn_parameters.py | 8 ++++---- src/jaxsim/api/model.py | 2 +- tests/test_api_model_hw_parametrization.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 6acf287d2..17e25e607 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1054,7 +1054,7 @@ def _convert_scaling_to_3d_vector( @staticmethod def compute_contact_points( original_contact_params: jtp.Vector, - shape_types: jtp.Vector, + link_shapes: jtp.Vector, original_com_positions: jtp.Vector, updated_com_positions: jtp.Vector, scaling_factors: ScalingFactors, @@ -1065,7 +1065,7 @@ def compute_contact_points( Args: original_contact_params: The original contact parameters. - shape_types: The shape types of the links (e.g., box, sphere, cylinder). + link_shapes: The shape types of the links (e.g., box, sphere, cylinder). original_com_positions: The original center of mass positions of the links. updated_com_positions: The updated center of mass positions of the links. scaling_factors: The scaling factors for the link dimensions. @@ -1083,7 +1083,7 @@ def compute_contact_points( ) # Extract the shape types of the parent links. - parent_shape_types = jnp.array(shape_types[parent_link_indices]) + parent_link_shapes = jnp.array(link_shapes[parent_link_indices]) def sphere(parent_idx, L_p_C): r = scaling_factors.dims[parent_idx][0] @@ -1108,7 +1108,7 @@ def box(parent_idx, L_p_C): shape_idx, (box, cylinder, sphere), parent_idx, L_p_C ) )( - parent_shape_types, + parent_link_shapes, parent_link_indices, L_p_Ci, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index eb30436eb..c3ac2fbc5 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2396,7 +2396,7 @@ def update_hw_parameters( # Compute the contact parameters points = HwLinkMetadata.compute_contact_points( original_contact_params=kin_dyn_params.contact_parameters, - shape_types=updated_hw_link_metadata.shape, + link_shapes=updated_hw_link_metadata.link_shape, original_com_positions=link_parameters.center_of_mass, updated_com_positions=updated_link_parameters.center_of_mass, scaling_factors=scaling_factors, diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index a41317228..9d389710a 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -402,7 +402,7 @@ def test_hw_parameters_collision_scaling( scaling_factor = 5.0 # Define the nominal radius of the sphere - nominal_height = model.kin_dyn_parameters.hw_link_metadata.dims[0, 2] + nominal_height = model.kin_dyn_parameters.hw_link_metadata.geometry[0, 2] # Define scaling parameters scaling_parameters = ScalingFactors(