diff --git a/.gitignore b/.gitignore index 05520e402..1448c0216 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ src/jaxsim/_version.py # data .mp4 .png + +# macOS +.DS_Store diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index 4061b30ff..d1ea88041 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -9,21 +9,10 @@ Collision Dynamics Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are: -- ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere. - - *Default:* ``50``. - -- ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection. +- ``JAXSIM_COLLISION_ENABLE_CYLINDER``: Enables collision dynamics for cylindrical geometries. *Default:* ``False``. -- ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere. - - *Default:* ``False``. - -.. note:: - The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame. - Testing ~~~~~~~ diff --git a/environment.yml b/environment.yml index 6cec4bb33..446c77766 100644 --- a/environment.yml +++ b/environment.yml @@ -15,7 +15,6 @@ dependencies: - pptree - qpax - rod >= 0.3.3 - - trimesh - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb index e74dc6e23..0faa12f23 100644 --- a/examples/jaxsim_as_physics_engine_advanced.ipynb +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -130,10 +130,7 @@ "# JaxSim currently only supports collisions between points attached to bodies\n", "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", "# While this approach is universal as it applies to generic meshes, the number\n", - "# of considered points greatly affects the performance. Spheres, by default,\n", - "# are discretized with 250 points. It's too much for this simple example.\n", - "# This number can be decreased with the following environment variable.\n", - "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" + "# of considered points greatly affects the performance." ] }, { diff --git a/pixi.lock b/pixi.lock index dc7587d92..4ecf17197 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f187b9d350d9c9d6383de3c3bc8d677d1e4b8199224f28d3dbfcfbc80357a68e -size 513548 +oid sha256:04de87ab540a1d618bddf653a4b55906fcc7f01fbd7ccdb454079ca6d5930342 +size 512946 diff --git a/pyproject.toml b/pyproject.toml index 3353e7148..6b011c1ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ dependencies = [ "qpax", "rod >= 0.4.1", "typing_extensions ; python_version < '3.12'", - "trimesh", ] [project.optional-dependencies] @@ -223,8 +222,7 @@ jax-dataclasses = "*" pptree = "*" optax = "*" qpax = "*" -rod = ">=0.4.1" -trimesh = "*" +rod = "*" typing_extensions = "*" # # Optional dependencies. diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b56c75fab..0f359dd9b 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -10,14 +10,14 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform -from jaxsim.rbda.contacts import SoftContacts +from jaxsim.rbda.contacts import SoftContacts, detection from .common import VelRepr @jax.jit @js.common.named_scope -def collidable_point_kinematics( +def contact_point_kinematics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -36,18 +36,33 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, + _, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: jaxsim.rbda.contacts.common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) + )( + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) - return W_p_Ci, W_ṗ_Ci + return W_p_Ci, W_ṗ_Ci @jax.jit @js.common.named_scope -def collidable_point_positions( +def contact_point_positions( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -61,14 +76,14 @@ def collidable_point_positions( The position of the collidable points in the world frame. """ - W_p_Ci, _ = collidable_point_kinematics(model=model, data=data) + W_p_Ci, _ = contact_point_kinematics(model=model, data=data) return W_p_Ci @jax.jit @js.common.named_scope -def collidable_point_velocities( +def contact_point_velocities( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -82,7 +97,7 @@ def collidable_point_velocities( The 3D velocity of the collidable points. """ - _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data) + _, W_ṗ_Ci = contact_point_kinematics(model=model, data=data) return W_ṗ_Ci @@ -112,15 +127,15 @@ def in_contact( raise ValueError("One or more link names are not part of the model") # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_contact_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_contact_points ) - parent_link_idx_of_enabled_collidable_points = jnp.array( + parent_link_idx_of_enabled_contact_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_contact_points] - W_p_Ci = collidable_point_positions(model=model, data=data) + W_p_Ci = contact_point_positions(model=model, data=data) terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))( W_p_Ci[:, 0], W_p_Ci[:, 1] @@ -136,7 +151,7 @@ def in_contact( links_in_contact = jax.vmap( lambda link_index: jnp.where( - parent_link_idx_of_enabled_collidable_points == link_index, + parent_link_idx_of_enabled_contact_points == link_index, below_terrain, jnp.zeros_like(below_terrain, dtype=bool), ).any() @@ -162,7 +177,7 @@ def estimate_good_contact_parameters( *, standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + number_of_active_contact_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, ) -> jaxsim.rbda.contacts.ContactParamsTypes: @@ -173,7 +188,7 @@ def estimate_good_contact_parameters( model: The model to consider. standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. - number_of_active_collidable_points_steady_state: + number_of_active_contact_points_steady_state: The number of active collidable points in steady state. damping_ratio: The damping ratio. max_penetration: The maximum penetration allowed. @@ -194,19 +209,19 @@ def estimate_good_contact_parameters( zero_data = js.data.JaxSimModelData.build(model=model) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] if model.floating_base(): - W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] + W_pz_C = contact_point_positions(model=model, data=zero_data)[:, -1] W_pz_CoM = W_pz_CoM - W_pz_C.min() # Consider as default a 1% of the model center of mass height. max_penetration = 0.01 * W_pz_CoM - nc = number_of_active_collidable_points_steady_state + nc = number_of_active_contact_points_steady_state return model.contact_model._parameters_class().build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, max_penetration=max_penetration, - number_of_active_collidable_points_steady_state=nc, + number_of_active_contact_points_steady_state=nc, damping_ratio=damping_ratio, ) @@ -225,34 +240,41 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt The stacked SE(3) matrices of all enabled collidable points. Note: + The output shape is (nL, 3, 4, 4), where nL is the number of links. + Three candidate contact points are considered for each collidable shape. Each collidable point is implicitly associated with a frame :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the collidable point and :math:`[L]` is the orientation frame of the link it is rigidly attached to. """ - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - # Get the transforms of the parent link of all collidable points. - W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points - ] + W_H_L = data._link_transforms + + # Index transforms by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_H_L_indexed = W_H_L[body_indices] + + def _process_single_shape(shape_type, shape_size, shape_transform, W_H_Li): + # Apply the collision shape transform to get W_H_S + W_H_S = W_H_Li @ shape_transform + + _, W_H_C = jax.lax.switch( + shape_type, + (detection.box_plane, detection.cylinder_plane, detection.sphere_plane), + model.terrain, + shape_size, + W_H_S, + ) - # Build the link-to-point transform from the displacement between the link frame L - # and the implicit contact frame C. - L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci) + return W_H_C - # Compose the work-to-link and link-to-point transforms. - return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) + return jax.vmap(_process_single_shape)( + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + model.kin_dyn_parameters.contact_parameters.transform, + W_H_L_indexed, + ) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @@ -283,69 +305,48 @@ def jacobian( rigidly attached to. """ - output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation - ) - - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + output_vel_repr = output_vel_repr or data.velocity_representation - # Compute the Jacobians of all links. + # Compute link-level Jacobians (n_links, 6, 6+n) W_J_WL = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=VelRepr.Inertial ) - # Compute the contact Jacobian. - # In inertial-fixed output representation, the Jacobian of the parent link is also - # the Jacobian of the frame C implicitly associated with the collidable point. - W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points] + # Compute contact transforms (n_shapes, n_contacts_per_shape, 4, 4) + W_H_C = transforms(model=model, data=data) + + # Index Jacobians by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_J_WL_indexed = W_J_WL[body_indices] # (n_shapes, 6, 6+n) + + # Repeat for each contact point per shape: (n_shapes*n_contacts_per_shape, 6, 6+n) + W_J_WC_flat = jnp.repeat(W_J_WL_indexed, 3, axis=0) + + # Flatten contact transforms (n_shapes*n_contacts_per_shape, 4, 4) + W_H_C_flat = W_H_C.reshape(-1, 4, 4) - # Adjust the output representation. + # Transform Jacobian based on velocity representation match output_vel_repr: case VelRepr.Inertial: - O_J_WC = W_J_WC + return W_J_WC_flat case VelRepr.Body: - W_H_C = transforms(model=model, data=data) - - def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - C_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_C, inverse=True - ) - C_J_WC = C_X_W @ W_J_WC - return C_J_WC - - O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) + def transform_jacobian(H_C, J_WC): + return jaxsim.math.Adjoint.from_transform(H_C, inverse=True) @ J_WC case VelRepr.Mixed: - W_H_C = transforms(model=model, data=data) - - def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - - CW_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_CW, inverse=True - ) - - CW_J_WC = CW_X_W @ W_J_WC - return CW_J_WC - - O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) + def transform_jacobian(H_C, J_WC): + H_CW = H_C.at[0:3, 0:3].set(jnp.eye(3)) + return jaxsim.math.Adjoint.from_transform(H_CW, inverse=True) @ J_WC case _: - raise ValueError(output_vel_repr) + raise ValueError(f"Unsupported velocity representation: {output_vel_repr}") - return O_J_WC + # Single vmap over all contact points + return jax.vmap(transform_jacobian)(W_H_C_flat, W_J_WC_flat) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @@ -373,39 +374,28 @@ def jacobian_derivative( velocity representation. """ - output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation - ) - - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - # Get the index of the parent link and the position of the collidable point. - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + output_vel_repr = output_vel_repr or data.velocity_representation - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points - ] + # Get the link velocities. + W_v_WL = data._link_velocities - # Get the transforms of all the parent links. - W_H_Li = data._link_transforms + # Index link velocities by body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_v_WL_indexed = W_v_WL[body_indices] # (n_shapes, 6) - # Get the link velocities. - W_v_WLi = data._link_velocities + # Compute the contact transforms (n_shapes, n_contacts, 4, 4) + W_H_C = transforms(model=model, data=data) # ===================================================== # Compute quantities to adjust the input representation # ===================================================== - def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix: + def compute_T(X: jtp.Matrix) -> jtp.Matrix: In = jnp.eye(model.dofs()) T = jax.scipy.linalg.block_diag(X, In) return T - def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: + def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: On = jnp.zeros(shape=(model.dofs(), model.dofs())) Ṫ = jax.scipy.linalg.block_diag(Ẋ, On) return Ṫ @@ -414,99 +404,71 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # time derivative. match data.velocity_representation: case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) - - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - + W_X = Adjoint.from_transform(jnp.eye(4)) + W_Ẋ = jnp.zeros((6, 6)) case VelRepr.Body: - W_H_B = data._base_transform - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB - - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - + W_X = Adjoint.from_transform(data.base_transform) + W_Ẋ = W_X @ Cross.vx(data.base_velocity) case VelRepr.Mixed: - W_H_B = data._base_transform - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW - - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) - + W_H_BW = data.base_transform.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(W_H_BW) + BW_v_W_BW = data.base_velocity.at[3:6].set(0) + W_X = W_X_BW + W_Ẋ = W_X_BW @ Cross.vx(BW_v_W_BW) case _: raise ValueError(data.velocity_representation) + T = compute_T(W_X) + Ṫ = compute_Ṫ(W_Ẋ) + # ===================================================== # Compute quantities to adjust the output representation # ===================================================== with data.switch_velocity_representation(VelRepr.Inertial): # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.model.generalized_free_floating_jacobian( - model=model, - data=data, - ) + W_J_WL_W = js.model.generalized_free_floating_jacobian(model=model, data=data) + # Compute the Jacobian derivative of the parent link in inertial representation. W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( - model=model, - data=data, + model=model, data=data ) - def compute_O_J̇_WC_I( - L_p_C: jtp.Vector, - parent_link_idx: jtp.Int, - W_H_L: jtp.Matrix, - ) -> jtp.Matrix: + # Index Jacobians by body (parent link) of each collision shape + W_J_WL_W_indexed = W_J_WL_W[body_indices] # (n_shapes, 6, 6+n) + W_J̇_WL_W_indexed = W_J̇_WL_W[body_indices] # (n_shapes, 6, 6+n) + def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: match output_vel_repr: case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 - transform=jnp.eye(4) - ) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841 - + O_X_W = jnp.eye(6) + O_Ẋ_W = jnp.zeros((6, 6)) case VelRepr.Body: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - W_v_WC = W_v_WLi[parent_link_idx] - W_vx_WC = Cross.vx(W_v_WC) - O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841 - + O_X_W = Adjoint.from_transform(W_H_C, inverse=True) + O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WL) case VelRepr.Mixed: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - CW_H_W = Transform.inverse(W_H_CW) - O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W) - CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx] - W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) - W_vx_W_CW = Cross.vx(W_v_W_CW) - O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841 - + O_X_W = Adjoint.from_transform(Transform.inverse(W_H_CW)) + v_CW = O_X_W @ W_v_WL + O_Ẋ_W = -O_X_W @ Cross.vx(v_CW.at[:3].set(v_CW[:3])) case _: raise ValueError(output_vel_repr) - O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs())) - O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T - O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T - O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ + O_J̇_WC_I = O_Ẋ_W @ W_J_WL_W @ T + O_J̇_WC_I += O_X_W @ W_J̇_WL_W @ T + O_J̇_WC_I += O_X_W @ W_J_WL_W @ Ṫ return O_J̇_WC_I - O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))( - L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li - ) + O_J̇_per_shape = jax.vmap( + lambda H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape: jax.vmap( + compute_O_J̇_WC_I, + in_axes=(0, None, None, None), # Map over contacts for W_H_C only + )(H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape), + in_axes=(0, 0, 0, 0), # Map over shapes + )(W_H_C, W_v_WL_indexed, W_J_WL_W_indexed, W_J̇_WL_W_indexed) + + O_J̇_WC = O_J̇_per_shape.reshape(-1, 6, 6 + model.dofs()) return O_J̇_WC @@ -537,7 +499,7 @@ def link_contact_forces( """ # Compute the contact forces for each collidable point with the active contact model. - W_f_C, aux_dict = model.contact_model.compute_contact_forces( + W_f_L, aux_dict = model.contact_model.compute_contact_forces( model=model, data=data, **( @@ -549,7 +511,7 @@ def link_contact_forces( # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) + # W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) return W_f_L, aux_dict @@ -575,8 +537,8 @@ def link_forces_from_contact_forces( contact_parameters = model.kin_dyn_parameters.contact_parameters # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_contact_points = ( + contact_parameters.indices_of_enabled_contact_points ) # Convert the contact forces to a JAX array. @@ -585,13 +547,13 @@ def link_forces_from_contact_forces( # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + parent_link_index_of_contact_points = jnp.array(contact_parameters.body, dtype=int)[ + indices_of_enabled_contact_points + ] # Create the mask that associate each collidable point to their parent link. # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + mask = parent_link_index_of_contact_points[:, jnp.newaxis] == jnp.arange( model.number_of_links() ) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 113620f89..cbae3d145 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -176,7 +176,7 @@ def build( if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts): contact_state["tangential_deformation"] = contact_state.get( "tangential_deformation", - jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point), + jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.center), ) model_data = JaxSimModelData( diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index e111354c2..981a3d2c5 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -14,8 +14,15 @@ import jaxsim.typing as jtp from jaxsim.math import Inertia, JointModel, supported_joint_motion from jaxsim.math.adjoint import Adjoint -from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription -from jaxsim.utils import HashedNumpyArray, JaxsimDataclass +from jaxsim.parsers.descriptions import ( + BoxCollision, + CylinderCollision, + JointDescription, + JointType, + ModelDescription, + SphereCollision, +) +from jaxsim.utils import CollidableShapeType, HashedNumpyArray, JaxsimDataclass @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) @@ -762,6 +769,13 @@ def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix: return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float) +_COLLISION_SHAPE_MAP = { + SphereCollision: CollidableShapeType.Sphere, + BoxCollision: CollidableShapeType.Box, + CylinderCollision: CollidableShapeType.Cylinder, +} + + @jax_dataclasses.pytree_dataclass class ContactParameters(JaxsimDataclass): """ @@ -769,14 +783,15 @@ class ContactParameters(JaxsimDataclass): Attributes: body: - A tuple of integers representing, for each collidable point, the index of - the body (link) to which it is rigidly attached to. - point: - The translations between the link frame and the collidable point, expressed - in the coordinates of the parent link frame. - enabled: - A tuple of booleans representing, for each collidable point, whether it is - enabled or not in contact models. + A tuple of integers representing, for each collision shape, the index of + the link to which it is rigidly attached to. + transform: + The 4x4 homogeneous transformation matrices representing the pose of each + collision shape with respect to the parent link frame. + shape_size: + The size parameters of each collidable shape. + shape_type: + The type of each collidable shape (sphere, box, cylinder, etc.). Note: Contrarily to LinkParameters and JointParameters, this class is not meant @@ -785,16 +800,23 @@ class ContactParameters(JaxsimDataclass): body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) - point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) + transform: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) + shape_size: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + shape_type: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) - enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple) + @property + def center(self) -> jtp.Array: + """Extract translation vectors from transformation matrices.""" + if self.transform.size == 0: + return jnp.array([]) + return self.transform[:, :3, 3] @property - def indices_of_enabled_collidable_points(self) -> npt.NDArray: - """ - Return the indices of the enabled collidable points. - """ - return np.where(np.array(self.enabled))[0] + def orientation(self) -> jtp.Array: + """Extract rotation matrices from transformation matrices.""" + if self.transform.size == 0: + return jnp.array([]) + return self.transform[:, :3, :3] @staticmethod def build_from(model_description: ModelDescription) -> ContactParameters: @@ -811,33 +833,40 @@ def build_from(model_description: ModelDescription) -> ContactParameters: if len(model_description.collision_shapes) == 0: return ContactParameters() - # Get all the links so that we can take their updated index. - links_dict = {link.name: link for link in model_description} + shape_types, shape_sizes, transforms, parent_link_indices = ( + [], + [], + [], + [], + ) - # Get all the enabled collidable points of the model. - collidable_points = model_description.all_enabled_collidable_points() + for collision in model_description.collision_shapes: + shape_type = _COLLISION_SHAPE_MAP.get( + type(collision), CollidableShapeType.Unsupported + ) - # Extract the positions L_p_C of the collidable points w.r.t. the link frames - # they are rigidly attached to. - points = jnp.vstack([cp.position for cp in collidable_points]) + # Skip unsupported collision shapes + if shape_type == CollidableShapeType.Unsupported: + continue - # Extract the indices of the links to which the collidable points are rigidly - # attached to. - link_index_of_points = tuple( - links_dict[cp.parent_link.name].index for cp in collidable_points - ) + shape_types.append(shape_type) - # Build the ContactParameters object. - cp = ContactParameters( - point=points, - body=link_index_of_points, - enabled=tuple(True for _ in link_index_of_points), - ) + shape_sizes.append(collision.size.squeeze()) - assert cp.point.shape[1] == 3, cp.point.shape[1] - assert cp.point.shape[0] == len(cp.body), cp.point.shape[0] + transforms.append(collision.transform) - return cp + # Get the parent link index for this collision shape. + parent_link_indices.append( + model_description.links_dict[collision.parent_link].index + ) + + # Build the ContactParameters object. + return ContactParameters( + body=tuple(parent_link_indices), + transform=jnp.array(transforms, dtype=float), + shape_type=jnp.array(shape_types, dtype=int), + shape_size=jnp.array(shape_sizes, dtype=float), + ) @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index c4ea9a56d..25f14dd20 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -54,7 +54,7 @@ def system_acceleration( W_f_L_terrain = jnp.zeros_like(f_L) contact_state_derivative = {} - if len(model.kin_dyn_parameters.contact_parameters.body) > 0: + if len(model.kin_dyn_parameters.contact_parameters.center) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index ff3bf631d..6a08ae6f3 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,8 +1,6 @@ from .collision import ( BoxCollision, - CollidablePoint, - CollisionShape, - MeshCollision, + CylinderCollision, SphereCollision, ) from .joint import JointDescription, JointGenericAxis, JointType diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 719c92d2b..cf63f71e7 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -1,178 +1,69 @@ from __future__ import annotations -import abc import dataclasses +from abc import ABC -import jax.numpy as jnp import numpy as np -import numpy.typing as npt import jaxsim.typing as jtp -from jaxsim import logging - -from .link import LinkDescription @dataclasses.dataclass -class CollidablePoint: +class CollisionShape(ABC): """ - Represents a collidable point associated with a parent link. + Base class for collision shapes. - Attributes: - parent_link: The parent link to which the collidable point is attached. - position: The position of the collidable point relative to the parent link. - enabled: A flag indicating whether the collidable point is enabled for collision detection. + This class serves as a base for specific collision shapes like BoxCollision and SphereCollision. + It is not intended to be instantiated directly. """ - parent_link: LinkDescription - position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) - enabled: bool = True - - def change_link( - self, new_link: LinkDescription, new_H_old: npt.NDArray - ) -> CollidablePoint: - """ - Move the collidable point to a new parent link. - - Args: - new_link (LinkDescription): The new parent link to which the collidable point is moved. - new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame. - - Returns: - CollidablePoint: A new collidable point associated with the new parent link. - """ - - msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}" - logging.debug(msg=msg) - - return CollidablePoint( - parent_link=new_link, - position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3], - enabled=self.enabled, - ) + size: jtp.VectorLike + parent_link: str + transform: jtp.MatrixLike = dataclasses.field(default_factory=lambda: np.eye(4)) def __hash__(self) -> int: - return hash( ( + hash(tuple(self.size.tolist())), hash(self.parent_link), - hash(tuple(self.position.tolist())), - hash(self.enabled), + hash(tuple(self.transform.flatten().tolist())), ) ) - def __eq__(self, other: CollidablePoint) -> bool: + def __eq__(self, other: CollisionShape) -> bool: - if not isinstance(other, CollidablePoint): + if not isinstance(other, CollisionShape): return False return hash(self) == hash(other) - def __str__(self) -> str: - return ( - f"{self.__class__.__name__}(" - + f"parent_link={self.parent_link.name}" - + f", position={self.position}" - + f", enabled={self.enabled}" - + ")" - ) - - -@dataclasses.dataclass -class CollisionShape(abc.ABC): - """ - Abstract base class for representing collision shapes. - - Attributes: - collidable_points: A list of collidable points associated with the collision shape. - """ - - collidable_points: tuple[CollidablePoint] + @property + def center(self) -> jtp.Vector: + """Extract the translation from the transformation matrix.""" + return self.transform[:3, 3] - def __str__(self): - return ( - f"{self.__class__.__name__}(" - + "collidable_points=[\n " - + ",\n ".join(str(cp) for cp in self.collidable_points) - + "\n])" - ) + @property + def orientation(self) -> jtp.Matrix: + """Extract the rotation matrix from the transformation matrix.""" + return self.transform[:3, :3] @dataclasses.dataclass class BoxCollision(CollisionShape): """ Represents a box-shaped collision shape. - - Attributes: - center: The center of the box in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) - - def __eq__(self, other: BoxCollision) -> bool: - - if not isinstance(other, BoxCollision): - return False - - return hash(self) == hash(other) - @dataclasses.dataclass class SphereCollision(CollisionShape): """ Represents a spherical collision shape. - - Attributes: - center: The center of the sphere in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) - - def __eq__(self, other: BoxCollision) -> bool: - - if not isinstance(other, BoxCollision): - return False - - return hash(self) == hash(other) - @dataclasses.dataclass -class MeshCollision(CollisionShape): +class CylinderCollision(CollisionShape): """ - Represents a mesh-shaped collision shape. - - Attributes: - center: The center of the mesh in the local frame of the collision shape. + Represents a cylindrical collision shape. """ - - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(tuple(self.center.tolist())), - hash(self.collidable_points), - ) - ) - - def __eq__(self, other: MeshCollision) -> bool: - if not isinstance(other, MeshCollision): - return False - - return hash(self) == hash(other) diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 6b3fdb02d..624f6b22c 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -1,13 +1,11 @@ from __future__ import annotations import dataclasses -import itertools from collections.abc import Sequence from jaxsim import logging from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose -from .collision import CollidablePoint, CollisionShape from .joint import JointDescription from .link import LinkDescription @@ -27,9 +25,7 @@ class ModelDescription(KinematicGraph): fixed_base: bool = True - collision_shapes: tuple[CollisionShape, ...] = dataclasses.field( - default_factory=list, repr=False - ) + collision_shapes: tuple = dataclasses.field(default_factory=list, repr=False) @staticmethod def build_model_from( @@ -37,7 +33,7 @@ def build_model_from( links: list[LinkDescription], joints: list[JointDescription], frames: list[LinkDescription] | None = None, - collisions: tuple[CollisionShape, ...] = (), + collisions: tuple = (), fixed_base: bool = False, base_link_name: str | None = None, considered_joints: Sequence[str] | None = None, @@ -80,61 +76,55 @@ def build_model_from( fk = KinematicGraphTransforms(graph=kinematic_graph) # Container of the final model's collision shapes. - final_collisions: list[CollisionShape] = [] + final_collisions: list = [] # Move and express the collision shapes of removed links to the resulting # lumped link that replace the combination of the removed link and its parent. for collision_shape in collisions: - # Get all the collidable points of the shape - coll_points = tuple(collision_shape.collidable_points) - - # Assume they have an unique parent link - if not len(set({cp.parent_link.name for cp in coll_points})) == 1: - msg = "Collision shape not currently supported (multiple parent links)" - raise RuntimeError(msg) - # Get the parent link of the collision shape. # Note that this link could have been lumped and we need to find the # link in which it was lumped into. - parent_link_of_shape = collision_shape.collidable_points[0].parent_link + parent_link_of_shape = collision_shape.parent_link # If it is part of the (reduced) graph, add it as it is... - if parent_link_of_shape.name in kinematic_graph.link_names(): + if parent_link_of_shape in kinematic_graph.link_names(): final_collisions.append(collision_shape) continue # ... otherwise look for the frame - if parent_link_of_shape.name not in kinematic_graph.frame_names(): + if parent_link_of_shape not in kinematic_graph.frame_names(): msg = "Parent frame '{}' of collision shape not found, ignoring shape" - logging.info(msg.format(parent_link_of_shape.name)) + logging.info(msg.format(parent_link_of_shape)) continue - # Create a new collision shape - new_collision_shape = CollisionShape(collidable_points=()) - final_collisions.append(new_collision_shape) + # Find the link that is part of the (reduced) model in which the + # collision shape's parent was lumped into. + real_parent_link_name = kinematic_graph.frames_dict[ + parent_link_of_shape + ].parent_name + + # Get the transform from the real parent link to the removed link + # that still exists as a frame. + parent_H_frame = fk.relative_transform( + relative_to=real_parent_link_name, + name=parent_link_of_shape, + ) - # If the frame was found, update the collidable points' pose and add them - # to the new collision shape. - for cp in collision_shape.collidable_points: - # Find the link that is part of the (reduced) model in which the - # collision shape's parent was lumped into - real_parent_link_name = kinematic_graph.frames_dict[ - parent_link_of_shape.name - ].parent_name - - # Change the link associated to the collidable point, updating their - # relative pose - moved_cp = cp.change_link( - new_link=kinematic_graph.links_dict[real_parent_link_name], - new_H_old=fk.relative_transform( - relative_to=real_parent_link_name, - name=cp.parent_link.name, - ), - ) - - # Store the updated collision. - new_collision_shape.collidable_points += (moved_cp,) + # Transform the collision shape's pose to the new parent link frame. + # The collision shape was defined w.r.t. the removed link (now a frame). + # Now we need to express it w.r.t. the link that absorbed the removed link. + # Compose the transforms: parent_H_shape = parent_H_frame @ frame_H_shape + parent_H_shape = parent_H_frame @ collision_shape.transform + + # Create a new collision shape with updated pose and parent link + new_collision_shape = dataclasses.replace( + collision_shape, + transform=parent_H_shape, + parent_link=real_parent_link_name, + ) + + final_collisions.append(new_collision_shape) # Build the model model = ModelDescription( @@ -193,63 +183,6 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: return reduced_model_description - def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None: - """ - Enable or disable collision shapes associated with a link. - - Args: - link_name: The name of the link. - enabled: Enable or disable collision shapes associated with the link. - """ - - if link_name not in self.link_names(): - raise ValueError(link_name) - - for point in self.collision_shape_of_link( - link_name=link_name - ).collidable_points: - point.enabled = enabled - - def collision_shape_of_link(self, link_name: str) -> CollisionShape: - """ - Get the collision shape associated with a specific link. - - Args: - link_name: The name of the link. - - Returns: - The collision shape associated with the link. - """ - - if link_name not in self.link_names(): - raise ValueError(link_name) - - return CollisionShape( - collidable_points=[ - point - for shape in self.collision_shapes - for point in shape.collidable_points - if point.parent_link.name == link_name - ] - ) - - def all_enabled_collidable_points(self) -> list[CollidablePoint]: - """ - Get all enabled collidable points in the model. - - Returns: - The list of all enabled collidable points. - - """ - - # Get iterator of all collidable points - all_collidable_points = itertools.chain.from_iterable( - [shape.collidable_points for shape in self.collision_shapes] - ) - - # Return enabled collidable points - return [cp for cp in all_collidable_points if cp.enabled] - def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): diff --git a/src/jaxsim/parsers/rod/meshes.py b/src/jaxsim/parsers/rod/meshes.py deleted file mode 100644 index 3679597e8..000000000 --- a/src/jaxsim/parsers/rod/meshes.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -import trimesh - -VALID_AXIS = {"x": 0, "y": 1, "z": 2} - - -def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: - """ - Extract the vertices of a mesh as points. - """ - return mesh.vertices - - -def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: - """ - Extract N random points from the surface of a mesh. - - Args: - mesh: The mesh from which to extract points. - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - return mesh.sample(n) - - -def extract_points_uniform_surface_sampling( - mesh: trimesh.Trimesh, n: int -) -> np.ndarray: - """ - Extract N uniformly sampled points from the surface of a mesh. - - Args: - mesh: The mesh from which to extract points. - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0] - - -def extract_points_select_points_over_axis( - mesh: trimesh.Trimesh, axis: str, direction: str, n: int -) -> np.ndarray: - """ - Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis. - - Args: - mesh: The mesh from which to extract points. - axis: The axis along which to extract points. - direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower". - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]} - arr = mesh.vertices - - # Sort rows lexicographically first, then columnar. - arr.sort(axis=0) - sorted_arr = arr[dirs[direction]] - return sorted_arr - - -def extract_points_aap( - mesh: trimesh.Trimesh, - axis: str, - upper: float | None = None, - lower: float | None = None, -) -> np.ndarray: - """ - Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. - - Args: - mesh: The mesh from which to extract points. - axis: The axis along which to extract points. - upper: The upper bound of the range. - lower: The lower bound of the range. - - Returns: - The extracted points (N x 3 array). - - Raises: - AssertionError: If the lower bound is greater than the upper bound. - """ - - # Check bounds. - upper = upper if upper is not None else np.inf - lower = lower if lower is not None else -np.inf - assert lower < upper, "Invalid bounds for axis-aligned plane" - - # Logic. - points = mesh.vertices[ - (mesh.vertices[:, VALID_AXIS[axis]] >= lower) - & (mesh.vertices[:, VALID_AXIS[axis]] <= upper) - ] - - return points diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 99d779d3e..aab6f48c3 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -27,7 +27,7 @@ class SDFData(NamedTuple): link_descriptions: list[descriptions.LinkDescription] joint_descriptions: list[descriptions.JointDescription] frame_descriptions: list[descriptions.LinkDescription] - collision_shapes: list[descriptions.CollisionShape] + collision_shapes: list sdf_model: rod.Model | None = None model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose() @@ -308,10 +308,12 @@ def extract_model_data( # ================ # Initialize the collision shapes - collisions: list[descriptions.CollisionShape] = [] + collisions = [] # Parse the collisions for link in sdf_model.links(): + # If a link has multiple collision shapes, we consider only the first + # supported one with priority box > sphere > cylinder. for collision in link.collisions(): if collision.geometry.box is not None: box_collision = utils.create_box_collision( @@ -320,6 +322,7 @@ def extract_model_data( ) collisions.append(box_collision) + break if collision.geometry.sphere is not None: sphere_collision = utils.create_sphere_collision( @@ -328,18 +331,28 @@ def extract_model_data( ) collisions.append(sphere_collision) + break - if collision.geometry.mesh is not None and int( - os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0") + if collision.geometry.cylinder is not None and int( + os.environ.get("JAXSIM_ENABLE_CYLINDER_COLLISION", 0) ): - logging.warning("Mesh collision support is still experimental.") - mesh_collision = utils.create_mesh_collision( + cylinder_collision = utils.create_cylinder_collision( collision=collision, link_description=links_dict[link.name], - method=utils.meshes.extract_points_vertices, ) - collisions.append(mesh_collision) + collisions.append(cylinder_collision) + break + + else: + # Fill with unsupported collision shape + collisions.append( + descriptions.collision.CollisionShape( + transform=jnp.eye(4), + size=jnp.array([0.0, 0.0, 0.0]), + parent_link=link.name, + ) + ) return SDFData( model_name=sdf_model.name, diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index a295b7fab..2a2eb0d63 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,21 +1,9 @@ -import os -import pathlib -from collections.abc import Callable -from typing import TypeVar - import numpy as np -import numpy.typing as npt import rod -import trimesh -from rod.utils.resolve_uris import resolve_local_uri import jaxsim.typing as jtp -from jaxsim import logging from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions -from jaxsim.parsers.rod import meshes - -MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray]) def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: @@ -115,43 +103,12 @@ def create_box_collision( x, y, z = collision.geometry.box.size - center = np.array([x / 2, y / 2, z / 2]) - - # Define the bottom corners. - bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]]) - - # Conditionally add the top corners based on the environment variable. - top_corners = ( - np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]]) - if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() - in { - "false", - "0", - } - else [] - ) - - # Combine and shift by the center - box_corners = np.vstack([bottom_corners, *top_corners]) - center - H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] - box_corners_wrt_link = ( - H @ np.hstack([box_corners, np.vstack([1.0] * box_corners.shape[0])]).T - )[0:3, :] - - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=np.array(corner), - enabled=True, - ) - for corner in box_corners_wrt_link.T - ] - return descriptions.BoxCollision( - collidable_points=collidable_points, center=center_wrt_link + size=np.array([x, y, z]), + transform=H, + parent_link=link_description.name, ) @@ -169,112 +126,38 @@ def create_sphere_collision( The sphere collision description. """ - # From https://stackoverflow.com/a/26127012 - def fibonacci_sphere(samples: int) -> npt.NDArray: - # Get the golden ratio in radians. - phi = np.pi * (3.0 - np.sqrt(5.0)) - - # Generate the points. - points = [ - np.array( - [ - np.cos(phi * i) - * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2), - y, - np.sin(phi * i) * np.sqrt(1 - y**2), - ] - ) - for i in range(samples) - ] - - # Filter to keep only the bottom half if required. - if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() in { - "true", - "1", - }: - # Keep only the points with z <= 0. - points = [point for point in points if point[2] <= 0] - - return np.vstack(points) - r = collision.geometry.sphere.radius - sphere_points = r * fibonacci_sphere( - samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50")) - ) - H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] - - sphere_points_wrt_link = ( - H @ np.hstack([sphere_points, np.vstack([1.0] * sphere_points.shape[0])]).T - )[0:3, :] - - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=np.array(point), - enabled=True, - ) - for point in sphere_points_wrt_link.T - ] - return descriptions.SphereCollision( - collidable_points=collidable_points, center=center_wrt_link + size=np.array([r] * 3), + transform=H, + parent_link=link_description.name, ) -def create_mesh_collision( - collision: rod.Collision, - link_description: descriptions.LinkDescription, - method: MeshMappingMethod = None, -) -> descriptions.MeshCollision: +def create_cylinder_collision( + collision: rod.Collision, link_description: descriptions.LinkDescription +) -> descriptions.CylinderCollision: """ - Create a mesh collision from an SDF collision element. + Create a cylinder collision from an SDF collision element. Args: collision: The SDF collision element. link_description: The link description. - method: The method to use for mesh wrapping. Returns: - The mesh collision description. + The cylinder collision description. """ - file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) - file_type = file.suffix.replace(".", "") - mesh = trimesh.load_mesh(file, file_type=file_type) + r = collision.geometry.cylinder.radius + l = collision.geometry.cylinder.length - if mesh.is_empty: - raise RuntimeError(f"Failed to process '{file}' with trimesh") + H = collision.pose.transform() if collision.pose is not None else np.eye(4) - mesh.apply_scale(collision.geometry.mesh.scale) - logging.info( - msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'" + return descriptions.CylinderCollision( + size=np.array([r, l, 0]), + transform=H, + parent_link=link_description.name, ) - - if method is None: - method = meshes.VertexExtraction() - logging.debug("Using default Vertex Extraction method for mesh wrapping") - else: - logging.debug(f"Using method {method} for mesh wrapping") - - points = method(mesh=mesh) - logging.debug(f"Extracted {len(points)} points from mesh") - - W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4) - - # Extract translation from transformation matrix - W_p_L = W_H_L[:3, 3] - mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=point, - enabled=True, - ) - for point in mesh_points_wrt_link - ] - - return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 5e0af2a66..022260e1b 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -1,6 +1,5 @@ from . import actuation, contacts from .aba import aba -from .collidable_points import collidable_points_pos_vel from .crba import crba from .forward_kinematics import forward_kinematics_model from .jacobian import ( diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py deleted file mode 100644 index 179126bb6..000000000 --- a/src/jaxsim/rbda/collidable_points.py +++ /dev/null @@ -1,65 +0,0 @@ -import jax -import jax.numpy as jnp - -import jaxsim.api as js -import jaxsim.typing as jtp -from jaxsim.math import Skew - - -def collidable_points_pos_vel( - model: js.model.JaxSimModel, - *, - link_transforms: jtp.Matrix, - link_velocities: jtp.Matrix, -) -> tuple[jtp.Matrix, jtp.Matrix]: - """ - - Compute the position and linear velocity of the enabled collidable points in the world frame. - - Args: - model: The model to consider. - link_transforms: The transforms from the world frame to each link. - link_velocities: The linear and angular velocities of each link. - - Returns: - A tuple containing the position and linear velocity of the enabled collidable points. - """ - - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points - ] - - if len(indices_of_enabled_collidable_points) == 0: - return jnp.array(0).astype(float), jnp.empty(0).astype(float) - - def process_point_kinematics( - Li_p_C: jtp.Vector, parent_body: jtp.Int - ) -> tuple[jtp.Vector, jtp.Vector]: - - # Compute the position of the collidable point. - W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3] - - # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}. - CW_vl_WCi = ( - jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) - @ link_velocities[parent_body].squeeze() - ) - - return W_p_Ci, CW_vl_WCi - - # Process all the collidable points in parallel. - W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( - L_p_Ci, - parent_link_idx_of_enabled_collidable_points, - ) - - return W_p_Ci, CW_vl_WC diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 32f05e229..af6e1d00e 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,5 @@ from . import relaxed_rigid, rigid, soft -from .common import ContactModel, ContactsParams +from .common import CollidableShapeType, ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams from .rigid import RigidContacts, RigidContactsParams from .soft import SoftContacts, SoftContactsParams diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index cc772f033..daae5b2ba 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -1,66 +1,96 @@ from __future__ import annotations import abc -import functools import jax import jax.numpy as jnp import jaxsim.api as js -import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.math import STANDARD_GRAVITY -from jaxsim.utils import JaxsimDataclass +from jaxsim.math import STANDARD_GRAVITY, Skew +from jaxsim.utils import CollidableShapeType, JaxsimDataclass try: from typing import Self except ImportError: from typing_extensions import Self +from .detection import box_plane, cylinder_plane, sphere_plane MAX_STIFFNESS = 1e6 MAX_DAMPING = 1e4 +# Define a mapping from collidable shape types to distance functions. +_COLLISION_MAP = { + CollidableShapeType.Sphere: sphere_plane, + CollidableShapeType.Box: box_plane, + CollidableShapeType.Cylinder: cylinder_plane, +} -@functools.partial(jax.jit, static_argnames=("terrain",)) + +@jax.jit def compute_penetration_data( - p: jtp.VectorLike, - v: jtp.VectorLike, - terrain: jaxsim.terrain.Terrain, + model: js.model.JaxSimModel, + *, + shape_transform: jtp.Matrix, + shape_type: CollidableShapeType, + shape_size: jtp.Vector, + link_transforms: jtp.Matrix, + link_velocities: jtp.Matrix, ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]: """ Compute the penetration data (depth, rate, and terrain normal) of a collidable point. Args: - p: The position of the collidable point. - v: - The linear velocity of the point (linear component of the mixed 6D velocity - of the implicit frame `C = (W_p_C, [W])` associated to the point). - terrain: The considered terrain. + model: The model to consider. + shape_transform: The 4x4 transform of the collidable shape with respect to the link frame. + shape_type: The type of the collidable shape. + shape_size: The size parameters of the collidable shape. + link_transforms: The transforms from the world frame to each link. + link_velocities: The linear and angular velocities of each link. Returns: A tuple containing the penetration depth, the penetration velocity, - and the considered terrain normal. + the terrain normal, the contact point position, and the contact point velocity + expressed in mixed representation. """ + W_H_L, W_ṗ_L = link_transforms, link_velocities + + # Apply the collision shape transform. + # This computes W_H_S where S is the collision shape frame. + W_H_S = W_H_L @ shape_transform + # Pre-process the position and the linear velocity of the collidable point. - W_ṗ_C = jnp.array(v).squeeze() - px, py, pz = jnp.array(p).squeeze() + # Note that we consider 3 candidate contact points also for spherical shapes, + # in which the output is padded with zeros. + # This is to allow parallel evaluation of the collision types. + δ, W_H_C = jax.lax.switch( + shape_type, + (box_plane, cylinder_plane, sphere_plane), + model.terrain, + shape_size, + W_H_S, + ) + + W_p_C = W_H_C[:, :3, 3] + n̂ = W_H_C[:, :3, 2] + + def process_shape_kinematics(W_p_Ci: jtp.Vector) -> jtp.Vector: + + # Compute the velocity of the contact points. + CW_ṗ_Ci = jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) @ W_ṗ_L - # Compute the terrain normal and the contact depth. - n̂ = terrain.normal(x=px, y=py).squeeze() - h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz]) + return CW_ṗ_Ci - # Compute the penetration depth normal to the terrain. - δ = jnp.maximum(0.0, jnp.dot(h, n̂)) + CW_ṗ_C = jax.vmap(process_shape_kinematics)(W_p_C) - # Compute the penetration normal velocity. - δ_dot = -jnp.dot(W_ṗ_C, n̂) + δ = jnp.maximum(0.0, -δ) - # Enforce the penetration rate to be zero when the penetration depth is zero. - δ_dot = jnp.where(δ > 0, δ_dot, 0.0) + δ̇ = -jax.vmap(jnp.dot)(CW_ṗ_C, n̂) + δ̇ = jnp.where(δ > 0, δ̇, 0.0) - return δ, δ_dot, n̂ + return δ, δ̇, n̂, W_p_C, CW_ṗ_C class ContactsParams(JaxsimDataclass): @@ -94,7 +124,6 @@ def build_default_from_jaxsim_model( standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, @@ -110,8 +139,6 @@ def build_default_from_jaxsim_model( standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. max_penetration: The maximum penetration depth. - number_of_active_collidable_points_steady_state: - The number of active collidable points in steady state. damping_ratio: The damping ratio. p: The first parameter of the contact model. q: The second parameter of the contact model. @@ -137,7 +164,6 @@ def build_default_from_jaxsim_model( ξ = damping_ratio δ_max = max_penetration μc = static_friction_coefficient - nc = number_of_active_collidable_points_steady_state # Compute the total mass of the model. m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() @@ -147,7 +173,7 @@ def build_default_from_jaxsim_model( # the damping term of the Hunt/Crossley model. if stiffness is None: # Compute the average support force on each collidable point. - f_average = m * standard_gravity / nc + f_average = m * standard_gravity stiffness = f_average / jnp.power(δ_max, 1 + p) stiffness = jnp.clip(stiffness, 0, MAX_STIFFNESS) diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py new file mode 100644 index 000000000..0c6b2180b --- /dev/null +++ b/src/jaxsim/rbda/contacts/detection.py @@ -0,0 +1,210 @@ +import jax +import jax.numpy as jnp + +import jaxsim +import jaxsim.typing as jtp + + +def _contact_frame(normal: jtp.Vector, position: jtp.Vector) -> jtp.Matrix: + """Create a contact frame with z-axis aligned with the contact normal.""" + n = normal / jaxsim.math.safe_norm(normal) + + t1_initial = jnp.array([1.0, 0.0, 0.0]) + + t1 = t1_initial - jnp.dot(t1_initial, n) * n + t1 = t1 / jaxsim.math.safe_norm(t1) + t2 = jnp.cross(n, t1) + + R = jnp.stack([t1, t2, n], axis=1) + + return jaxsim.math.Transform.from_rotation_and_translation( + rotation=R, + translation=position, + ) + + +def sphere_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Float, jtp.Matrix]: + """ + Detect contacts between a sphere and a plane terrain. + + Args: + terrain: The terrain object. + size: The size of the sphere. + W_H_L: The collision shape transform in world coordinates. + + Returns: + A tuple containing the distance from the sphere to the plane and the pose transform + of the contact frame. + """ + # Extract sphere center and radius. + center = W_H_L[0:3, 3] + radius = size[0] + + # Extract terrain properties at sphere center. + x, y = center[0], center[1] + + normal = terrain.normal(x=x, y=y) + height = terrain.height(x=x, y=y) + + distance = jnp.dot(center - height, normal) - radius + + position = center - radius * normal + + W_H_C = _contact_frame(normal, position) + + # Pad distance and transform to match expected output shapes. + # and allow parallel evaluation of the collision types. + distance = jnp.pad(jnp.array([distance]), (0, 2), mode="empty") + W_H_C = jnp.pad(W_H_C[jnp.newaxis, ...], ((0, 2), (0, 0), (0, 0)), mode="empty") + + return distance, W_H_C + + +def box_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Vector, jtp.Matrix]: + """ + Return distances and contact frames of the 3 deepest corners of a box on terrain using SDF. + Fully vectorized, works for any box orientation. + """ + half_size = size.squeeze() / 2 + + R = W_H_L[:3, :3] + center = W_H_L[:3, 3] + + # Generate all 8 corners using meshgrid + sx = jnp.array([-half_size[0], half_size[0]]) + sy = jnp.array([-half_size[1], half_size[1]]) + sz = jnp.array([-half_size[2], half_size[2]]) + xs, ys, zs = jnp.meshgrid(sx, sy, sz, indexing="ij") + corners_local = jnp.stack( + [xs.ravel(), ys.ravel(), zs.ravel()], axis=1 + ) # shape (8,3) + + # Project box z-axis on terrain normal and ensure direction away from plane + sign = jnp.sign(R[:, 2]) + R_corrected = R.at[:, 2].set(R[:, 2] * sign) + + # Transform to world frame + corners_world = center + corners_local @ R_corrected.T + + # Vectorized terrain height and normal using vmap + terrain_height_vmap = jax.vmap(lambda p: terrain.height(p[0], p[1])) + terrain_normal_vmap = jax.vmap(lambda p: terrain.normal(p[0], p[1])) + + terrain_heights = terrain_height_vmap(corners_world) + terrain_points = jnp.stack( + [corners_world[:, 0], corners_world[:, 1], terrain_heights], axis=1 + ) + + normals = terrain_normal_vmap(corners_world) + + # Distances along terrain normal + distances = jnp.einsum("ij,ij->i", corners_world - terrain_points, normals) + + # Pick 3 closest points using top_k + _, topk_idx = jax.lax.top_k(-distances, 3) + contact_points = corners_world[topk_idx] + contact_normals = normals[topk_idx] + + # Compute contact frames using vmap + W_H_C = jax.vmap(lambda p, n: _contact_frame(n, p))(contact_points, contact_normals) + + # Distances along terrain normal for the selected points + distances_top3 = distances[topk_idx] + + return distances_top3, W_H_C + + +def cylinder_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Vector, jtp.Matrix]: + """ + Return distances and contact frames of the 3 deepest points of a cylinder on terrain. + + Args: + terrain: The terrain object. + size: The size of the cylinder (radius, height). + W_H_L: The collision shape transform in world coordinates. + + Returns: + A tuple containing the distances from the cylinder to the plane and the pose transforms + of the contact frames. + """ + + size = size.squeeze() + r, half_h = size[0], size[1] * 0.5 + + # Cylinder pose + position = W_H_L[:3, 3] + R = W_H_L[:3, :3] + axis = R[:, 2] + + # Terrain data at cylinder XY + h = terrain.height(position[0], position[1]) + n = terrain.normal(position[0], position[1]) + plane_position = jnp.array([position[0], position[1], h]) + + # Project axis on normal and ensure direction away from plane + prjaxis = jnp.dot(n, axis) + sign = -jnp.sign(prjaxis + 1e-12) + axis, prjaxis = axis * sign, prjaxis * sign + + # Distance from cylinder centre to plane along normal + dist0 = jnp.dot(position - plane_position, n) + + # Remove component along normal from axis + vec = axis * prjaxis - n + len_vec = jnp.linalg.norm(vec) + vec = jnp.where( + len_vec < 1e-12, + R[:, 0] * r, # disk parallel to plane + vec / len_vec * r, # general case + ) + + # Project vec along normal + prjvec = jnp.dot(vec, n) + + # Scale axis by half length + ax_scaled = axis * half_h + prjaxis_h = prjaxis * half_h + + # Sideways vector for 3-point support + prjvec1 = -0.5 * prjvec + vec1 = jnp.cross(vec, ax_scaled) + vec1 = vec1 / (jnp.linalg.norm(vec1) + 1e-12) * r * jnp.sqrt(3.0) * 0.5 + + # Distances of three candidate contacts: + d1 = dist0 + prjaxis_h + prjvec + d2 = dist0 + prjaxis_h + prjvec1 + dist = jnp.array([d1, d2, d2]) + + # World position of candidates + position_c = ( + position + + ax_scaled + + jnp.array( + [ + vec - n * d1 * 0.5, + vec1 + vec * 0.5 + n * d2 * 0.5, + -vec1 + vec * 0.5 + n * d2 * 0.5, + ] + ) + ) + + # Handle case in which the cylinder lies on the disks + condition = jnp.abs(prjaxis) < 1e-3 + d3 = dist0 - prjaxis_h + prjvec + dist = jnp.where(condition, dist.at[1].set(d3), dist) + position_c = jnp.where( + condition, + position_c.at[1].set(position + vec - ax_scaled - n * d3 * 0.5), + position_c, + ) + + # Build contact frames on the three candidate points + W_H_C = jax.vmap(lambda p: _contact_frame(n, p))(position_c) + + return dist, W_H_C diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 0b08082ce..7334cea94 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import functools from collections.abc import Callable from typing import Any @@ -14,7 +15,7 @@ import jaxsim.typing as jtp from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr -from . import common, soft +from . import common, detection, soft try: from typing import Self @@ -325,16 +326,27 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) - # Compute the position and linear velocities (mixed representation) of - # all collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) - # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. - δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) + )( + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) # Compute the position in the constraint frame. @@ -344,13 +356,16 @@ def compute_contact_forces( a_ref, r, *_ = self._regularizers( model=model, position_constraint=position_constraint, - velocity_constraint=velocity, + velocity_constraint=CW_ṗ_C, parameters=model.contact_params, ) # Compute the transforms of the implicit frames corresponding to the # collidable points. - W_H_C = js.contact.transforms(model=model, data=data) + # The final shape will be (n_links, 3 (max_contact_points), 4, 4). + W_H_C = jax.vmap( + lambda n, p: jax.vmap(detection._contact_frame)(n, p), + )(n̂, W_p_C) with ( data.switch_velocity_representation(VelRepr.Mixed), @@ -372,15 +387,17 @@ def compute_contact_forces( # Compute the linear part of the Jacobian of the collidable points Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( - js.contact.jacobian(model=model, data=data)[:, :3, :], δ + js.contact.jacobian(model=model, data=data)[:, :3], + jnp.concatenate(δ), ) ) # Compute the linear part of the Jacobian derivative of the collidable points J̇l_WC = jnp.vstack( jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( - js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ - ), + js.contact.jacobian_derivative(model=model, data=data)[:, :3], + jnp.concatenate(δ), + ) ) # Compute the Delassus matrix for contacts (mixed representation). @@ -468,20 +485,20 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: # ====================================== # Initialize the optimized forces with a linear Hunt/Crossley model. + hunt_crossley_closure = functools.partial( + soft.SoftContacts.hunt_crossley_contact_model, + K=1e6, + D=2e3, + p=0.5, + q=0.5, + mu=0.0, + tangential_deformation=jnp.zeros(3), + ) + init_params = jax.vmap( - lambda p, v: soft.SoftContacts.hunt_crossley_contact_model( - position=p, - velocity=v, - terrain=model.terrain, - K=1e6, - D=2e3, - p=0.5, - q=0.5, - # No tangential initial forces. - mu=0.0, - tangential_deformation=jnp.zeros(3), - )[0] - )(position, velocity).flatten() + jax.vmap(hunt_crossley_closure, in_axes=(0, 0, 0, 0)), # map over contacts + in_axes=(0, 0, 0, 0), # map over links + )(δ, δ̇, CW_ṗ_C, n̂)[0].flatten() # Get the solver options. solver_options = self.solver_options @@ -509,21 +526,34 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: ) # Reshape the optimized solution to be a matrix of 3D contact forces. - CW_fl_C = solution.reshape(-1, 3) + CW_fl_per_link = solution.reshape(-1, 3, 3) + + # Transform each contact force to inertial frame + def to_inertial(force, H_C): + return ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(force), + transform=H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) - # Convert the contact forces from mixed to inertial-fixed representation. + # Compute the contact forces in inertial representation for + # each link and contact point. + # Nested vmap: inner over contacts, outer over shapes W_f_C = jax.vmap( - lambda CW_fl_C, W_H_C: ( - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - transform=W_H_C, - other_representation=VelRepr.Mixed, - is_force=True, - ) - ), - )(CW_fl_C, W_H_C) + lambda f_shape, H_shape: jax.vmap(to_inertial)(f_shape, H_shape) + )(CW_fl_per_link, W_H_C) - return W_f_C, {} + # Sum over contacts for each shape: (n_shapes, 6) + W_f_per_shape = W_f_C.sum(axis=1) + + # Accumulate forces by parent link using segment_sum + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_f_per_link = jax.ops.segment_sum( + W_f_per_shape, body_indices, num_segments=model.number_of_links() + ) + + return W_f_per_link, {} @staticmethod def _regularizers( @@ -563,17 +593,12 @@ def _regularizers( ) ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - # Compute the 6D inertia matrices of all links. - M_L = js.model.link_spatial_inertia_matrices(model=model) + M_L_all = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3] + + # Index M_L by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + M_L = M_L_all[body_indices] def imp_aref( pos: jtp.Vector, @@ -621,7 +646,7 @@ def imp_aref( def compute_row( *, - link_idx: jtp.Int, + M_Li: jtp.Matrix, pos: jtp.Vector, vel: jtp.Vector, ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]: @@ -630,11 +655,7 @@ def compute_row( ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel) # Compute the regularization term. - R = ( - (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) - * (1 + μ**2) - @ jnp.linalg.inv(M_L[link_idx, :3, :3]) - ) + R = (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) * (1 + μ**2) @ jnp.linalg.inv(M_Li) # Return the computed values, setting them to zero in case of no contact. is_active = (pos.dot(pos) > 0).astype(float) @@ -643,13 +664,11 @@ def compute_row( ) a_ref, R, K, D = jax.tree.map( - f=jnp.concatenate, - tree=( - *jax.vmap(compute_row)( - link_idx=parent_link_idx_of_enabled_collidable_points, - pos=position_constraint, - vel=velocity_constraint, - ), + jnp.ravel, + jax.vmap(compute_row)( + M_Li=M_L, + pos=position_constraint, + vel=velocity_constraint, ), ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 961133b9c..3a649ab24 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import functools import jax import jax.numpy as jnp @@ -11,7 +10,6 @@ import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.terrain import Terrain from . import common @@ -197,12 +195,13 @@ def update_velocity_after_impact( return data @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) + @jax.jit def hunt_crossley_contact_model( - position: jtp.VectorLike, + penetration: jtp.VectorLike, + penetration_rate: jtp.VectorLike, velocity: jtp.VectorLike, + normal: jtp.VectorLike, tangential_deformation: jtp.VectorLike, - terrain: Terrain, K: jtp.FloatLike, D: jtp.FloatLike, mu: jtp.FloatLike, @@ -213,10 +212,11 @@ def hunt_crossley_contact_model( Compute the contact force using the Hunt/Crossley model. Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - tangential_deformation: The material deformation of the collidable point. - terrain: The terrain model. + penetration: The penetration of the collision point. + penetration_rate: The penetration rate of the collision point. + velocity: The velocity of the contact point. + normal: The terrain normal at the contact point. + tangential_deformation: The material deformation of the collidable shape. K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. @@ -232,17 +232,14 @@ def hunt_crossley_contact_model( material deformation. """ - # Convert the input vectors to arrays. - W_p_C = jnp.array(position, dtype=float).squeeze() - W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() - m = jnp.array(tangential_deformation, dtype=float).squeeze() - - # Use symbol for the static friction. + # Use symbols for input parameters. + W_ṗ_C = velocity + m = tangential_deformation + δ = penetration + δ̇ = penetration_rate + n̂ = normal μ = mu - # Compute the penetration depth, its rate, and the considered terrain normal. - δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) - # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. # To avoid these issues, we introduce a small tolerance ε to their arguments @@ -343,53 +340,63 @@ def hunt_crossley_contact_model( return CW_fl, ṁ @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) + @jax.jit def compute_contact_force( - position: jtp.VectorLike, - velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, + penetration: jtp.Float, + penetration_rate: jtp.Float, + position: jtp.Vector, + velocity: jtp.Vector, + normal: jtp.Vector, + tangential_deformation: jtp.Vector, parameters: SoftContactsParams, - terrain: Terrain, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the contact force. Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - tangential_deformation: The material deformation of the collidable point. + penetration: The penetration of the collision point. + penetration_rate: The penetration rate of the collision point. + position: The position of the contact point. + velocity: The velocity of the contact point. + normal: The terrain normal at the contact point. + tangential_deformation: The material deformation of the collidable shape. parameters: The parameters of the soft contacts model. - terrain: The terrain model. Returns: A tuple containing the computed contact force and the derivative of the material deformation. """ - CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( - position=position, - velocity=velocity, - tangential_deformation=tangential_deformation, - terrain=terrain, - K=parameters.K, - D=parameters.D, - mu=parameters.mu, - p=parameters.p, - q=parameters.q, + CW_fl, ṁ = jax.vmap( + SoftContacts.hunt_crossley_contact_model, + in_axes=(0, 0, 0, 0, None, None, None, None, None, None), + )( + penetration, + penetration_rate, + velocity, + normal, + tangential_deformation, + parameters.K, + parameters.D, + parameters.mu, + parameters.p, + parameters.q, ) # Pack a mixed 6D force. - CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) + CW_f = jax.vmap(lambda f: jnp.hstack([f, jnp.zeros(3)]))(f=CW_fl) # Compute the 6D force transform from the mixed to the inertial-fixed frame. - W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( - translation=jnp.array(position), inverse=True - ).T + W_Xf_CW = jax.vmap( + lambda W_p_C: jaxsim.math.Adjoint.from_quaternion_and_translation( + translation=jnp.array(W_p_C), inverse=True + ).T + )(W_p_C=position) # Compute the 6D force in the inertial-fixed frame. - W_f = W_Xf_CW @ CW_f + W_f = jnp.einsum("...ij,...j->...i", W_Xf_CW, CW_f) - return W_f, ṁ + return jnp.sum(W_f, axis=0), jnp.mean(ṁ, axis=0) @staticmethod @jax.jit @@ -409,40 +416,50 @@ def compute_contact_forces( second element a dictionary with derivative of the material deformation. """ - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - # Compute the position and linear velocities (mixed representation) of - # all the collidable points belonging to the robot and extract the ones - # for the enabled collidable points. - W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) - - # Extract the material deformation corresponding to the collidable points. - m = ( - data.contact_state["tangential_deformation"] - if "tangential_deformation" in data.contact_state - else jnp.zeros_like(W_p_C) + # all the collidable shapes belonging to the robot and extract the ones + # for the enabled collidable shapes. + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) + )( + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) - m_enabled = m[indices_of_enabled_collidable_points] + # Extract the material deformation corresponding to the collidable shapes. + m = data.contact_state["tangential_deformation"] - # Initialize the tangential deformation rate array for every collidable point. + # Initialize the tangential deformation rate array for every collidable shape. ṁ = jnp.zeros_like(m) - # Compute the contact forces only for the enabled collidable points. + # Compute the contact forces for all the collidable shapes. # Since we treat them as independent, we can vmap the computation. - W_f, ṁ_enabled = jax.vmap( - lambda p, v, m: SoftContacts.compute_contact_force( - position=p, - velocity=v, - tangential_deformation=m, - parameters=model.contact_params, - terrain=model.terrain, - ) - )(W_p_C, W_ṗ_C, m_enabled) - - ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) + # We exploit two levels of vmap to vectorize over both the shapes and the points. + # The outer vmap vectorizes over the shapes, while the inner vmap vectorizes + # over the maximum points (3) belonging to each shape. + W_f_per_shape, ṁ = jax.vmap( + SoftContacts.compute_contact_force, + in_axes=(0, 0, 0, 0, 0, 0, None), # vectorize over shapes + )(δ, δ̇, W_p_C, CW_ṗ_C, n̂, m, model.contact_params) + + # Accumulate forces by parent link using segment_sum + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_f = jax.ops.segment_sum( + W_f_per_shape, body_indices, num_segments=model.number_of_links() + ) - return W_f, {"m_dot": ṁ} + return W_f, {"m_dot": ṁ} diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index d0b881ceb..67c7bcb98 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -1,5 +1,21 @@ +import dataclasses +from typing import ClassVar + from jax_dataclasses._copy_and_mutate import _Mutability as Mutability from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing from .wrappers import HashedNumpyArray, HashlessObject + + +# TODO (flferretti): Definitely not the best place for this +@dataclasses.dataclass(frozen=True) +class CollidableShapeType: + """ + Enum representing the types of collidable shapes. + """ + + Unsupported: ClassVar[int] = -1 + Box: ClassVar[int] = 0 + Cylinder: ClassVar[int] = 1 + Sphere: ClassVar[int] = 2 diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index d83c2f5bc..eb7ab8fc4 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +import numpy as np import pytest import rod @@ -22,53 +23,34 @@ def test_contact_kinematics( velocity_representation=velocity_representation, ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - # ===== # Tests # ===== - # Compute the pose of the implicit contact frame associated to the collidable points + # Compute the pose of the implicit contact frame associated to the collidable shapes # and the transforms of all links. W_H_C = js.contact.transforms(model=model, data=data) - W_H_L = data._link_transforms - - # Check that the orientation of the implicit contact frame matches with the - # orientation of the link to which the contact point is attached. - for contact_idx, index_of_parent_link in enumerate( - parent_link_idx_of_enabled_collidable_points - ): - assert W_H_C[contact_idx, 0:3, 0:3] == pytest.approx( - W_H_L[index_of_parent_link][0:3, 0:3] - ) # Check that the origin of the implicit contact frame is located over the - # collidable point. - W_p_C = js.contact.collidable_point_positions(model=model, data=data) - assert W_p_C == pytest.approx(W_H_C[:, 0:3, 3]) + # collidable shape. + W_p_C = js.contact.contact_point_positions(model=model, data=data) + assert W_p_C == pytest.approx(W_H_C[:, :, 0:3, 3]) - # Compute the velocity of the collidable point. + # Compute the velocity of the collidable shape. # This quantity always matches with the linear component of the mixed 6D velocity - # of the implicit frame associated to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + # of the implicit frame associated to the collidable shape. + W_ṗ_C = js.contact.contact_point_velocities(model=model, data=data) - # Compute the velocity of the collidable point using the contact Jacobian. + # Compute the velocity of the collidable shape using the contact Jacobian. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] # Compare the two velocities. - assert W_ṗ_C == pytest.approx(CW_vl_WC) + assert jnp.concatenate(W_ṗ_C) == pytest.approx(CW_vl_WC) -def test_collidable_point_jacobians( +def test_contact_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, @@ -85,19 +67,19 @@ def test_collidable_point_jacobians( # Tests # ===== - # Compute the velocity of the collidable points with a RBDA. + # Compute the velocity of the collidable shapes with a RBDA. # This function always returns the linear part of the mixed velocity of the - # implicit frame C corresponding to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + # implicit frame C corresponding to the collidable shape. + W_ṗ_C = js.contact.contact_point_velocities(model=model, data=data) # Compute the generalized velocity and the free-floating Jacobian of the frame C. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) - # Compute the velocity of the collidable points using the Jacobians. + # Compute the velocity of the collidable shapes using the Jacobians. v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) - assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3]) + assert jnp.concatenate(W_ṗ_C) == pytest.approx(v_WC_from_jax[:, 0:3]) def test_contact_jacobian_derivative( @@ -115,22 +97,24 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) + body_indices = np.array(model.kin_dyn_parameters.contact_parameters.body) - # Extract the parent link names and the poses of the contact points. - parent_link_names = js.link.idxs_to_names( - model=model, - link_indices=jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points], - ) + # Get link transforms for each collision shape + W_H_L = data._link_transforms[body_indices] + + # Get contact point positions (shape: num_collision_shapes, 3, 3) + W_p_C = js.contact.contact_point_positions(model=model, data=data) - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points - ] + # Transform contact points from world to link frame + # For each collision shape, transform its 3 contact points + def transform_to_link_frame(W_H_L_i, W_p_Ci): + """Transform 3 contact points from world to link frame.""" + + L_H_W = jnp.linalg.inv(W_H_L_i) + return jax.vmap(lambda p: (L_H_W @ jnp.hstack([p, 1.0]))[:3])(W_p_Ci) + + # Apply to all collision shapes: shape (num_collision_shapes, 3, 3) + L_p_Ci = jax.vmap(transform_to_link_frame)(W_H_L, W_p_C) # ===== # Tests @@ -139,19 +123,22 @@ def test_contact_jacobian_derivative( # Load the model in ROD. rod_model = rod.Sdf.load(sdf=model.built_from).model - # Add dummy frames on the contact points. - for idx, link_name, L_p_C in zip( - indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True + for shape_idx, (link_idx, points) in enumerate( + zip(body_indices, L_p_Ci, strict=True) ): - rod_model.add_frame( - frame=rod.Frame( - name=f"contact_point_{idx}", - attached_to=link_name, - pose=rod.Pose( - relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C) + link_name = model.link_names()[link_idx] + + for j, p in enumerate(points): + rod_model.add_frame( + frame=rod.Frame( + name=f"contact_shape_{shape_idx}_{j}", + attached_to=link_name, + pose=rod.Pose( + relative_to=link_name, + pose=jnp.zeros((6,)).at[0:3].set(p), + ), ), - ), - ) + ) # Rebuild the JaxSim model. model_with_frames = js.model.JaxSimModel.build_from_model_description( @@ -173,17 +160,17 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # Extract the indexes of the frames attached to the contact points. + # Extract the indexes of the frames attached to the contact shapes. + num_collision_shapes = len(model.kin_dyn_parameters.contact_parameters.body) frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( - f"contact_point_{idx}" for idx in indices_of_enabled_collidable_points + f"contact_shape_{shape_idx}_{j}" + for shape_idx in range(num_collision_shapes) + for j in range(3) ), ) - # Check that the number of frames is correct. - assert len(frame_idxs) == len(parent_link_names) - # Compute the contact Jacobian derivative. J̇_WC = js.contact.jacobian_derivative( model=model_with_frames, data=data_with_frames diff --git a/tests/test_meshes.py b/tests/test_meshes.py deleted file mode 100644 index d9bd66dcc..000000000 --- a/tests/test_meshes.py +++ /dev/null @@ -1,103 +0,0 @@ -import trimesh - -from jaxsim.parsers.rod import meshes - - -def test_mesh_wrapping_vertex_extraction(): - """ - Test the vertex extraction method on different meshes. - - 1. A simple box. - 2. A sphere. - """ - - # Test 1: A simple box. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box( - extents=[3.0, 3.0, 3.0], - ) - points = meshes.extract_points_vertices(mesh=mesh) - assert len(points) == len(mesh.vertices) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - points = meshes.extract_points_vertices(mesh=mesh) - assert len(points) == len(mesh.vertices) - - -def test_mesh_wrapping_aap(): - """ - Test the AAP wrapping method on different meshes. - - 1. A simple box - 1.1: Remove all points above x=0.0 - 1.2: Remove all points below y=0.0 - 2. A sphere - """ - - # Test 1.1: Remove all points above x=0.0. - # The expected result is that the number of points is halved. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) - points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0) - assert len(points) == len(mesh.vertices) // 2 - assert all(points[:, 0] > 0.0) - - # Test 1.2: Remove all points below y=0.0. - # The expected result is that the number of points is halved. - points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0) - assert len(points) == len(mesh.vertices) // 2 - assert all(points[:, 1] < 0.0) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - # Points are expected to be halved. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - - # Remove all points above y=0.0. - points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0) - assert all(points[:, 1] >= 0.0) - assert len(points) < len(mesh.vertices) - - -def test_mesh_wrapping_points_over_axis(): - """ - Test the points over axis method on different meshes. - - 1. A simple box - 1.1: Select 10 points from the lower end of the x-axis - 1.2: Select 10 points from the higher end of the y-axis - 2. A sphere - """ - - # Test 1.1: Remove 10 points from the lower end of the x-axis. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="x", direction="lower", n=4 - ) - assert len(points) == 4 - assert all(points[:, 0] < 0.0) - - # Test 1.2: Select 10 points from the higher end of the y-axis. - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="y", direction="higher", n=4 - ) - assert len(points) == 4 - assert all(points[:, 1] > 0.0) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - sphere_n_vertices = len(mesh.vertices) - - # Select 10 points from the higher end of the z-axis. - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2 - ) - assert len(points) == sphere_n_vertices // 2 - assert all(points[:, 2] >= 0.0) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 79ef1eed2..f1c47ba39 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -194,7 +194,7 @@ def test_simulation_with_soft_contacts( model = jaxsim_model_box - # Define the maximum penetration of each collidable point at steady state. + # Define the maximum penetration at steady state. max_penetration = 0.001 with model.editable(validate=False) as model: @@ -202,21 +202,11 @@ def test_simulation_with_soft_contacts( model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() model.contact_params = js.contact.estimate_good_contact_parameters( model=model, - number_of_active_collidable_points_steady_state=4, static_friction_coefficient=1.0, damping_ratio=1.0, max_penetration=max_penetration, ) - # Enable a subset of the collidable points. - enabled_collidable_points_mask = np.zeros( - len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool - ) - enabled_collidable_points_mask[[0, 1, 2, 3]] = True - model.kin_dyn_parameters.contact_parameters.enabled = tuple( - enabled_collidable_points_mask.tolist() - ) - assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Check jaxsim_model_box@conftest.py.