diff --git a/docs/modules/api.rst b/docs/modules/api.rst index 8a1e884d7..8b4cd3bc6 100644 --- a/docs/modules/api.rst +++ b/docs/modules/api.rst @@ -21,12 +21,6 @@ Contact .. automodule:: jaxsim.api.contact :members: -Soft Contacts -""""""""""""" - -.. automodule:: jaxsim.api.soft_contact - :members: - KinDynParameters ~~~~~~~~~~~~~~~~ diff --git a/docs/modules/rbda.rst b/docs/modules/rbda.rst index c3a17d888..f1fdc1da1 100644 --- a/docs/modules/rbda.rst +++ b/docs/modules/rbda.rst @@ -28,6 +28,12 @@ Collision Detection .. automodule:: jaxsim.rbda.collidable_points :members: +Contact Models +~~~~~~~~~~~~~~ + +.. automodule:: jaxsim.rbda.soft_contacts + :members: + Composite Rigid Body Algorithm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 7b593e897..c97f7eb30 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -1,17 +1,14 @@ from __future__ import annotations -import abc -import dataclasses import functools import jax import jax.numpy as jnp -import jax_dataclasses import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.utils import JaxsimDataclass +from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsParams from .common import VelRepr @@ -135,7 +132,6 @@ def collidable_point_dynamics( `C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose. Instead, the 6D forces are returned in the active representation. """ - from .soft_contacts import SoftContacts # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. @@ -154,7 +150,7 @@ def collidable_point_dynamics( # Note that the material deformation rate is always returned in the mixed frame # C[W] = (W_p_C, [W]). This is convenient for integration purpose. W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)( - W_p_Ci, W_ṗ_Ci, data.state.contact_state.tangential_deformation + W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation ) case _: @@ -226,7 +222,7 @@ def estimate_good_soft_contacts_parameters( number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, -) -> js.soft_contacts.SoftContactsParams: +) -> SoftContactsParams: """ Estimate good soft contacts parameters for the given model. @@ -250,13 +246,14 @@ def estimate_good_soft_contacts_parameters( The user is encouraged to fine-tune the parameters based on the specific application. """ + from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: """""" zero_data = js.data.JaxSimModelData.build( model=model, - contacts_params=js.soft_contacts.SoftContactsParams(), + contacts_params=SoftContactsParams(), ) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] @@ -275,7 +272,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: nc = number_of_active_collidable_points_steady_state - sc_parameters = js.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model( + sc_parameters = SoftContactsParams.build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, @@ -368,12 +365,10 @@ def jacobian( # Adjust the output representation. match output_vel_repr: - case VelRepr.Inertial: O_J_WC = W_J_WC case VelRepr.Body: - W_H_C = transforms(model=model, data=data) def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: @@ -386,11 +381,9 @@ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) case VelRepr.Mixed: - W_H_C = transforms(model=model, data=data) def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) CW_X_W = jaxsim.math.Adjoint.from_transform( @@ -406,96 +399,3 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: raise ValueError(output_vel_repr) return O_J_WC - - -@jax_dataclasses.pytree_dataclass -class ContactsState(JaxsimDataclass, abc.ABC): - """ - Abstract class storing the state of the contacts model. - """ - - @classmethod - def build(cls, **kwargs) -> ContactsState: - """ - Build the contact state object. - Returns: - The contact state object. - """ - - return cls(**kwargs) - - @classmethod - def zero(cls, **kwargs) -> ContactsState: - """ - Build a zero contact state. - Returns: - The zero contact state. - """ - - return cls.build(**kwargs) - - def valid(self, **kwargs) -> bool: - """ - Check if the contacts state is valid. - """ - - return True - - -@jax_dataclasses.pytree_dataclass -class ContactsParams(JaxsimDataclass, abc.ABC): - """ - Abstract class representing the parameters of a contact model. - """ - - @abc.abstractmethod - def build(self) -> ContactsParams: - """ - Create a `ContactsParams` instance with specified parameters. - Returns: - The `ContactsParams` instance. - """ - - raise NotImplementedError - - def valid(self, *args, **kwargs) -> bool: - """ - Check if the parameters are valid. - Returns: - True if the parameters are valid, False otherwise. - """ - - return True - - -@jax_dataclasses.pytree_dataclass -class ContactModel(abc.ABC): - """ - Abstract class representing a contact model. - Attributes: - parameters: The parameters of the contact model. - terrain: The terrain model. - """ - - parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams) - terrain: jaxsim.terrain.Terrain = dataclasses.field( - default_factory=jaxsim.terrain.FlatTerrain - ) - - @abc.abstractmethod - def contact_model( - self, - position: jtp.Vector, - velocity: jtp.Vector, - **kwargs, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Compute the contact forces. - Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - Returns: - A tuple containing the contact force and additional information. - """ - - raise NotImplementedError diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 8473cb269..9bbb48bcd 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -14,14 +14,13 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim.math import Quaternion +from jaxsim.rbda.contacts.soft_contacts import SoftContacts from jaxsim.utils import Mutability from jaxsim.utils.tracing import not_tracing from . import common from .common import VelRepr -from .contact import ContactsParams, ContactsState from .ode_data import ODEState -from .soft_contacts import SoftContacts try: from typing import Self @@ -39,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): gravity: jtp.Array - contacts_params: ContactsParams = dataclasses.field(repr=False) + contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False) time_ns: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array(0, dtype=jnp.uint64) @@ -114,8 +113,8 @@ def build( base_angular_velocity: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, - contacts_state: ContactsState | None = None, - contacts_params: ContactsParams | None = None, + contacts_state: jaxsim.rbda.ContactsState | None = None, + contacts_params: jaxsim.rbda.ContactsParams | None = None, velocity_representation: VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: @@ -658,7 +657,10 @@ def reset_base_linear_velocity( return self.reset_base_velocity( base_velocity=jnp.hstack( - [linear_velocity.squeeze(), self.base_velocity()[3:6]] + [ + linear_velocity.squeeze(), + self.base_velocity()[3:6], + ] ), velocity_representation=velocity_representation, ) @@ -686,7 +688,10 @@ def reset_base_angular_velocity( return self.reset_base_velocity( base_velocity=jnp.hstack( - [self.base_velocity()[0:3], angular_velocity.squeeze()] + [ + self.base_velocity()[0:3], + angular_velocity.squeeze(), + ] ), velocity_representation=velocity_representation, ) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index f9ffbc0a7..579bee865 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -184,7 +184,6 @@ def κb(link_index: jtp.IntLike) -> jtp.Vector: carry0 = κb, link_index def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: - κb, active_link_index = carry κb, active_link_index = jax.lax.cond( @@ -226,14 +225,12 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: ) def __eq__(self, other: KynDynParameters) -> bool: - if not isinstance(other, KynDynParameters): return False return hash(self) == hash(other) def __hash__(self) -> int: - return hash( ( hash(self.number_of_links()), @@ -643,7 +640,6 @@ def build_from_inertial_parameters( def build_from_flat_parameters( index: jtp.IntLike, parameters: jtp.VectorLike ) -> LinkParameters: - index = jnp.array(index).squeeze().astype(int) m = jnp.array(parameters[0]).squeeze().astype(float) @@ -668,7 +664,11 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector: return ( jnp.hstack( - [params.mass, params.center_of_mass.squeeze(), params.inertia_elements] + [ + params.mass, + params.center_of_mass.squeeze(), + params.inertia_elements, + ] ) .squeeze() .astype(float) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a49959e96..7ec97b9ca 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -20,8 +20,6 @@ from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr -from .contact import ContactModel -from .soft_contacts import SoftContacts @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) @@ -52,7 +50,7 @@ class JaxSimModel(JaxsimDataclass): def description(self) -> jaxsim.parsers.descriptions.ModelDescription: return self._description.get() - contact_model: ContactModel | None = dataclasses.field( + contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field( default=None, repr=False, compare=False, hash=False ) @@ -89,7 +87,7 @@ def build_from_model_description( model_name: str | None = None, *, terrain: jaxsim.terrain.Terrain | None = None, - contact_model: ContactModel | None = None, + contact_model: jaxsim.rbda.ContactModel | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -116,6 +114,7 @@ def build_from_model_description( """ import jaxsim.parsers.rod + from jaxsim.rbda.contacts.soft_contacts import SoftContacts # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description @@ -153,7 +152,7 @@ def build( model_name: str | None = None, *, terrain: jaxsim.terrain.Terrain | None = None, - contact_model: ContactModel | None = None, + contact_model: jaxsim.rbda.ContactModel | None = None, ) -> JaxSimModel: """ Build a Model object from an intermediate model description. @@ -172,6 +171,7 @@ def build( Returns: The built Model object. """ + from jaxsim.rbda.contacts.soft_contacts import SoftContacts # Set the model name (if not provided, use the one from the model description) model_name = model_name if model_name is not None else model_description.name diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 5d6463c85..86f617d1d 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -132,7 +132,7 @@ def system_velocity_dynamics( W_f_Ci = None # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. - ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float) + ṁ = jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float) if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces applied to each collidable point and the diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 9cae857a5..d6a74b902 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -1,11 +1,15 @@ from __future__ import annotations +import importlib + import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.api.soft_contacts import SoftContactsState +from jaxsim import logging +from jaxsim.rbda import ContactsState +from jaxsim.rbda.contacts.soft_contacts import SoftContactsState from jaxsim.utils import JaxsimDataclass # ============================================================================= @@ -117,11 +121,11 @@ class ODEState(JaxsimDataclass): Attributes: physics_model: The state of the physics model. - soft_contacts: The state of the soft-contacts model. + contacts_state: The state of the contacts model. """ physics_model: PhysicsModelState - soft_contacts: SoftContactsState + contacts_state: ContactsState @staticmethod def build_from_jaxsim_model( @@ -159,6 +163,15 @@ def build_from_jaxsim_model( `JaxSimModel` and initialized to zero. """ + # Get the contact model from the `JaxSimModel` + prefix = type(model.contact_model).__name__.split("Contact")[0] + + if prefix: + module_name = f"{prefix.lower()}_contacts" + class_name = f"{prefix.capitalize()}ContactsState" + else: + raise ValueError("Unable to determine contact state class prefix.") + return ODEState.build( model=model, physics_model_state=PhysicsModelState.build_from_jaxsim_model( @@ -170,24 +183,31 @@ def build_from_jaxsim_model( base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, ), - soft_contacts_state=SoftContactsState.build_from_jaxsim_model( + contacts_state=getattr( + importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"), + class_name, + ).build_from_jaxsim_model( model=model, - tangential_deformation=tangential_deformation, + **( + dict(tangential_deformation=tangential_deformation) + if tangential_deformation is not None + else dict() + ), ), ) @staticmethod def build( physics_model_state: PhysicsModelState | None = None, - soft_contacts_state: SoftContactsState | None = None, + contacts_state: ContactsState | None = None, model: js.model.JaxSimModel | None = None, ) -> ODEState: """ - Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`. + Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`. Args: physics_model_state: The state of the physics model. - soft_contacts_state: The state of the soft-contacts model. + contacts_state: The state of the contacts model. model: The `JaxSimModel` associated with the ODE state. Returns: @@ -200,14 +220,34 @@ def build( else PhysicsModelState.zero(model=model) ) - soft_contacts_state = ( - soft_contacts_state - if soft_contacts_state is not None + # Get the contact model from the `JaxSimModel` + try: + prefix = type(model.contact_model).__name__.split("Contact")[0] + except AttributeError: + logging.warning( + "Unable to determine contact state class prefix. Using default soft contacts." + ) + prefix = "Soft" + + module_name = f"{prefix.lower()}_contacts" + class_name = f"{prefix.capitalize()}ContactsState" + + try: + state_cls = getattr( + importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"), + class_name, + ) + except ImportError as e: + raise e + + contacts_state = ( + contacts_state + if contacts_state is not None else SoftContactsState.zero(model=model) ) return ODEState( - physics_model=physics_model_state, soft_contacts=soft_contacts_state + physics_model=physics_model_state, contacts_state=contacts_state ) @staticmethod @@ -237,7 +277,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool: `True` if the ODE state is valid for the given model, `False` otherwise. """ - return self.physics_model.valid(model=model) and self.soft_contacts.valid( + return self.physics_model.valid(model=model) and self.contacts_state.valid( model=model ) diff --git a/src/jaxsim/api/soft_contacts.py b/src/jaxsim/api/soft_contacts.py deleted file mode 100644 index 2ea418e6a..000000000 --- a/src/jaxsim/api/soft_contacts.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -import dataclasses - -import jax -import jax.numpy as jnp -import jax_dataclasses - -import jaxsim.api as js -import jaxsim.typing as jtp -from jaxsim.math import Skew, StandardGravity -from jaxsim.terrain import FlatTerrain, Terrain - -from .contact import ContactModel, ContactsParams, ContactsState - - -@jax_dataclasses.pytree_dataclass -class SoftContactsParams(ContactsParams): - """Parameters of the soft contacts model.""" - - K: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(1e6, dtype=float) - ) - - D: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(2000, dtype=float) - ) - - mu: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.5, dtype=float) - ) - - def __hash__(self) -> int: - - from jaxsim.utils.wrappers import HashedNumpyArray - - return hash( - ( - HashedNumpyArray.hash_of_array(self.K), - HashedNumpyArray.hash_of_array(self.D), - HashedNumpyArray.hash_of_array(self.mu), - ) - ) - - def __eq__(self, other: SoftContactsParams) -> bool: - - if not isinstance(other, SoftContactsParams): - return NotImplemented - - return hash(self) == hash(other) - - @staticmethod - def build( - K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5 - ) -> SoftContactsParams: - """ - Create a SoftContactsParams instance with specified parameters. - - Args: - K: The stiffness parameter. - D: The damping parameter of the soft contacts model. - mu: The static friction coefficient. - - Returns: - A SoftContactsParams instance with the specified parameters. - """ - - return SoftContactsParams( - K=jnp.array(K, dtype=float), - D=jnp.array(D, dtype=float), - mu=jnp.array(mu, dtype=float), - ) - - @staticmethod - def build_default_from_jaxsim_model( - model: js.model.JaxSimModel, - *, - standard_gravity: jtp.FloatLike = StandardGravity, - static_friction_coefficient: jtp.FloatLike = 0.5, - max_penetration: jtp.FloatLike = 0.001, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, - damping_ratio: jtp.FloatLike = 1.0, - ) -> SoftContactsParams: - """ - Create a SoftContactsParams instance with good default parameters. - - Args: - model: The target model. - standard_gravity: The standard gravity constant. - static_friction_coefficient: - The static friction coefficient between the model and the terrain. - max_penetration: The maximum penetration depth. - number_of_active_collidable_points_steady_state: - The number of contacts supporting the weight of the model - in steady state. - damping_ratio: The ratio controlling the damping behavior. - - Returns: - A `SoftContactsParams` instance with the specified parameters. - - Note: - The `damping_ratio` parameter allows to operate on the following conditions: - - ξ > 1.0: over-damped - - ξ = 1.0: critically damped - - ξ < 1.0: under-damped - """ - - # Use symbols for input parameters - ξ = damping_ratio - δ_max = max_penetration - μc = static_friction_coefficient - - # Compute the total mass of the model - m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() - - # Rename the standard gravity - g = standard_gravity - - # Compute the average support force on each collidable point - f_average = m * g / number_of_active_collidable_points_steady_state - - # Compute the stiffness to get the desired steady-state penetration - K = f_average / jnp.power(δ_max, 3 / 2) - - # Compute the damping using the damping ratio - critical_damping = 2 * jnp.sqrt(K * m) - D = ξ * critical_damping - - return SoftContactsParams.build(K=K, D=D, mu=μc) - - -@jax_dataclasses.pytree_dataclass -class SoftContacts(ContactModel): - """Soft contacts model.""" - - parameters: SoftContactsParams = dataclasses.field( - default_factory=SoftContactsParams - ) - - terrain: Terrain = dataclasses.field(default_factory=FlatTerrain) - - def contact_model( - self, - position: jtp.Vector, - velocity: jtp.Vector, - tangential_deformation: jtp.Vector, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Compute the contact forces and material deformation rate. - - Args: - position: The position of the collidable point. - velocity: The linear velocity of the collidable point. - tangential_deformation: The tangential deformation. - - Returns: - A tuple containing the contact force and material deformation rate. - """ - - # Short name of parameters - K = self.parameters.K - D = self.parameters.D - μ = self.parameters.mu - - # Material 3D tangential deformation and its derivative - m = tangential_deformation.squeeze() - ṁ = jnp.zeros_like(m) - - # Note: all the small hardcoded tolerances in this method have been introduced - # to allow jax differentiating through this algorithm. They should not affect - # the accuracy of the simulation, although they might make it less readable. - - # ======================== - # Normal force computation - # ======================== - - # Unpack the position of the collidable point - px, py, pz = W_p_C = position.squeeze() - vx, vy, vz = W_ṗ_C = velocity.squeeze() - - # Compute the terrain normal and the contact depth - n̂ = self.terrain.normal(x=px, y=py).squeeze() - h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz]) - - # Compute the penetration depth normal to the terrain - δ = jnp.maximum(0.0, jnp.dot(h, n̂)) - - # Compute the penetration normal velocity - δ̇ = -jnp.dot(W_ṗ_C, n̂) - - # Non-linear spring-damper model. - # This is the force magnitude along the direction normal to the terrain. - force_normal_mag = jax.lax.select( - pred=δ >= 1e-9, - on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇), - on_false=jnp.array(0.0), - ) - - # Prevent negative normal forces that might occur when δ̇ is largely negative - force_normal_mag = jnp.maximum(0.0, force_normal_mag) - - # Compute the 3D linear force in C[W] frame - force_normal = force_normal_mag * n̂ - - # ==================================== - # No friction and no tangential forces - # ==================================== - - # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial. - # Note: this is equal to the 6D velocities transform: CW_X_W.transpose(). - W_Xf_CW = jnp.vstack( - [ - jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]), - jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]), - ] - ) - - def with_no_friction(): - # Compute 6D mixed force in C[W] - CW_f_lin = force_normal - CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)]) - - # Compute lin-ang 6D forces (inertial representation) - W_f = W_Xf_CW @ CW_f - - return W_f, ṁ - - # ========================= - # Compute tangential forces - # ========================= - - def with_friction(): - # Initialize the tangential deformation rate ṁ. - # For inactive contacts with m≠0, this is the dynamics of the material - # relaxation converging exponentially to steady state. - ṁ = (-K / D) * m - - # Check if the collidable point is below ground. - # Note: when δ=0, we consider the point still not it contact such that - # we prevent divisions by 0 in the computations below. - active_contact = pz < self.terrain.height(x=px, y=py) - - def above_terrain(): - return jnp.zeros(6), ṁ - - def below_terrain(): - # Decompose the velocity in normal and tangential components - v_normal = jnp.dot(W_ṗ_C, n̂) * n̂ - v_tangential = W_ṗ_C - v_normal - - # Compute the tangential force. If inside the friction cone, the contact - f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential) - - def sticking_contact(): - # Sum the normal and tangential forces, and create the 6D force - CW_f_stick = force_normal + f_tangential - CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)]) - - # In this case the 3D material deformation is the tangential velocity - ṁ = v_tangential - - # Return the 6D force in the contact frame and - # the deformation derivative - return CW_f, ṁ - - def slipping_contact(): - # Project the force to the friction cone boundary - f_tangential_projected = (μ * force_normal_mag) * ( - f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9) - ) - - # Sum the normal and tangential forces, and create the 6D force - CW_f_slip = force_normal + f_tangential_projected - CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)]) - - # Correct the material deformation derivative for slipping contacts. - # Basically we compute ṁ such that we get `f_tangential` on the cone - # given the current (m, δ). - ε = 1e-9 - δε = jnp.maximum(δ, ε) - α = -K * jnp.sqrt(δε) - β = -D * jnp.sqrt(δε) - ṁ = (f_tangential_projected - α * m) / β - - # Return the 6D force in the contact frame and - # the deformation derivative - return CW_f, ṁ - - CW_f, ṁ = jax.lax.cond( - pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2, - true_fun=lambda _: slipping_contact(), - false_fun=lambda _: sticking_contact(), - operand=None, - ) - - # Express the 6D force in the world frame - W_f = W_Xf_CW @ CW_f - - # Return the 6D force in the world frame and the deformation derivative - return W_f, ṁ - - # (W_f, ṁ) - return jax.lax.cond( - pred=active_contact, - true_fun=lambda _: below_terrain(), - false_fun=lambda _: above_terrain(), - operand=None, - ) - - # (W_f, ṁ) - return jax.lax.cond( - pred=(μ == 0.0), - true_fun=lambda _: with_no_friction(), - false_fun=lambda _: with_friction(), - operand=None, - ) - - -@jax_dataclasses.pytree_dataclass -class SoftContactsState(ContactsState): - """ - Class storing the state of the soft contacts model. - - Attributes: - tangential_deformation: - The matrix of 3D tangential material deformations corresponding to - each collidable point. - """ - - tangential_deformation: jtp.Matrix - - def __hash__(self) -> int: - - return hash( - tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist()) - ) - - def __eq__(self, other: SoftContactsState) -> bool: - - if not isinstance(other, SoftContactsState): - return False - - return hash(self) == hash(other) - - @staticmethod - def build_from_jaxsim_model( - model: js.model.JaxSimModel | None = None, - tangential_deformation: jtp.Matrix | None = None, - ) -> SoftContactsState: - """ - Build a `SoftContactsState` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the soft contacts state. - tangential_deformation: The matrix of 3D tangential material deformations. - - Returns: - The `SoftContactsState` built from the `JaxSimModel`. - - Note: - If any of the state components are not provided, they are built from the - `JaxSimModel` and initialized to zero. - """ - - return SoftContactsState.build( - tangential_deformation=tangential_deformation, - number_of_collidable_points=len( - model.kin_dyn_parameters.contact_parameters.body - ), - ) - - @staticmethod - def build( - tangential_deformation: jtp.Matrix | None = None, - number_of_collidable_points: int | None = None, - ) -> SoftContactsState: - """ - Create a `SoftContactsState`. - - Args: - tangential_deformation: - The matrix of 3D tangential material deformations corresponding to - each collidable point. - number_of_collidable_points: The number of collidable points. - - Returns: - A `SoftContactsState` instance. - """ - - tangential_deformation = ( - tangential_deformation - if tangential_deformation is not None - else jnp.zeros(shape=(number_of_collidable_points, 3)) - ) - - if tangential_deformation.shape[1] != 3: - raise RuntimeError("The tangential deformation matrix must have 3 columns.") - - if ( - number_of_collidable_points is not None - and tangential_deformation.shape[0] != number_of_collidable_points - ): - msg = "The number of collidable points must match the number of rows " - msg += "in the tangential deformation matrix." - raise RuntimeError(msg) - - return SoftContactsState( - tangential_deformation=jnp.array(tangential_deformation).astype(float) - ) - - @staticmethod - def zero(model: js.model.JaxSimModel) -> SoftContactsState: - """ - Build a zero `SoftContactsState` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the soft contacts state. - - Returns: - A zero `SoftContactsState` instance. - """ - - return SoftContactsState.build_from_jaxsim_model(model=model) - - def valid(self, model: js.model.JaxSimModel) -> bool: - """ - Check if the `SoftContactsState` is valid for a given `JaxSimModel`. - - Args: - model: The `JaxSimModel` to validate the `SoftContactsState` against. - - Returns: - `True` if the soft contacts state is valid for the given `JaxSimModel`, - `False` otherwise. - """ - - shape = self.tangential_deformation.shape - expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3) - - if shape != expected: - return False - - return True diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 851e705dd..2eab36fb5 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -1,5 +1,6 @@ from .aba import aba from .collidable_points import collidable_points_pos_vel +from .contacts.common import ContactModel, ContactsParams, ContactsState from .crba import crba from .forward_kinematics import forward_kinematics, forward_kinematics_model from .jacobian import ( diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 6e9e5f793..f61e567b1 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -8,7 +8,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr -from jaxsim.api.soft_contacts import SoftContacts, SoftContactsParams +from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. @@ -342,7 +342,7 @@ def test_ad_integration( s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) - m = data.state.soft_contacts.tangential_deformation + m = data.state.contacts_state.tangential_deformation # Inputs. W_f_L = references.link_forces(model=model) @@ -396,7 +396,7 @@ def step( base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, ), - soft_contacts_state=js.ode_data.SoftContactsState.build( + contacts_state=js.ode_data.SoftContactsState.build( tangential_deformation=m ), ), @@ -417,7 +417,7 @@ def step( xf_s = data_xf.joint_positions(model=model) xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities(model=model) - xf_m = data_xf.state.soft_contacts.tangential_deformation + xf_m = data_xf.state.contacts_state.tangential_deformation return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m diff --git a/tests/test_simulations.py b/tests/test_simulations.py index c98ed6833..b5b7114a1 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -6,7 +6,7 @@ import jaxsim.integrators import jaxsim.rbda from jaxsim import VelRepr -from jaxsim.api.soft_contacts import SoftContactsParams +from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams def test_box_with_external_forces(