diff --git a/docs/modules/api.rst b/docs/modules/api.rst index 8a1e884d7..8b4cd3bc6 100644 --- a/docs/modules/api.rst +++ b/docs/modules/api.rst @@ -21,12 +21,6 @@ Contact .. automodule:: jaxsim.api.contact :members: -Soft Contacts -""""""""""""" - -.. automodule:: jaxsim.api.soft_contact - :members: - KinDynParameters ~~~~~~~~~~~~~~~~ diff --git a/docs/modules/rbda.rst b/docs/modules/rbda.rst index c3a17d888..f1fdc1da1 100644 --- a/docs/modules/rbda.rst +++ b/docs/modules/rbda.rst @@ -28,6 +28,12 @@ Collision Detection .. automodule:: jaxsim.rbda.collidable_points :members: +Contact Models +~~~~~~~~~~~~~~ + +.. automodule:: jaxsim.rbda.soft_contacts + :members: + Composite Rigid Body Algorithm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b83571515..18585f2b5 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -1,17 +1,14 @@ from __future__ import annotations -import abc -import dataclasses import functools import jax import jax.numpy as jnp -import jax_dataclasses import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.utils import JaxsimDataclass +from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams from .common import VelRepr @@ -226,7 +223,7 @@ def estimate_good_soft_contacts_parameters( number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, -) -> js.soft_contacts.SoftContactsParams: +) -> SoftContactsParams: """ Estimate good soft contacts parameters for the given model. @@ -250,13 +247,14 @@ def estimate_good_soft_contacts_parameters( The user is encouraged to fine-tune the parameters based on the specific application. """ + from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: """""" zero_data = js.data.JaxSimModelData.build( model=model, - contacts_params=js.soft_contacts.SoftContactsParams(), + contacts_params=SoftContactsParams(), ) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] @@ -275,7 +273,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: nc = number_of_active_collidable_points_steady_state - sc_parameters = js.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model( + sc_parameters = SoftContactsParams.build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, @@ -406,96 +404,3 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: raise ValueError(output_vel_repr) return O_J_WC - - -@jax_dataclasses.pytree_dataclass -class ContactsState(JaxsimDataclass, abc.ABC): - """ - Abstract class storing the state of the contacts model. - """ - - @classmethod - def build(cls, **kwargs) -> ContactsState: - """ - Build the contact state object. - Returns: - The contact state object. - """ - - return cls(**kwargs) - - @classmethod - def zero(cls, **kwargs) -> ContactsState: - """ - Build a zero contact state. - Returns: - The zero contact state. - """ - - return cls.build(**kwargs) - - def valid(self, **kwargs) -> bool: - """ - Check if the contacts state is valid. - """ - - return True - - -@jax_dataclasses.pytree_dataclass -class ContactsParams(JaxsimDataclass, abc.ABC): - """ - Abstract class representing the parameters of a contact model. - """ - - @abc.abstractmethod - def build(self) -> ContactsParams: - """ - Create a `ContactsParams` instance with specified parameters. - Returns: - The `ContactsParams` instance. - """ - - raise NotImplementedError - - def valid(self, *args, **kwargs) -> bool: - """ - Check if the parameters are valid. - Returns: - True if the parameters are valid, False otherwise. - """ - - return True - - -@jax_dataclasses.pytree_dataclass -class ContactModel(abc.ABC): - """ - Abstract class representing a contact model. - Attributes: - parameters: The parameters of the contact model. - terrain: The terrain model. - """ - - parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams) - terrain: jaxsim.terrain.Terrain = dataclasses.field( - default_factory=jaxsim.terrain.FlatTerrain - ) - - @abc.abstractmethod - def contact_model( - self, - position: jtp.Vector, - velocity: jtp.Vector, - **kwargs, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Compute the contact forces. - Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - Returns: - A tuple containing the contact force and additional information. - """ - - raise NotImplementedError diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 8473cb269..ee288b025 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -14,6 +14,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim.math import Quaternion +from jaxsim.rbda.contacts.soft_contacts import SoftContacts from jaxsim.utils import Mutability from jaxsim.utils.tracing import not_tracing @@ -21,7 +22,6 @@ from .common import VelRepr from .contact import ContactsParams, ContactsState from .ode_data import ODEState -from .soft_contacts import SoftContacts try: from typing import Self diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a49959e96..beddda06f 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -17,11 +17,10 @@ import jaxsim.parsers.descriptions import jaxsim.typing as jtp from jaxsim.math import Cross +from jaxsim.rbda import ContactModel from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr -from .contact import ContactModel -from .soft_contacts import SoftContacts @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) @@ -116,6 +115,7 @@ def build_from_model_description( """ import jaxsim.parsers.rod + from jaxsim.rbda.contacts.soft_contacts import SoftContacts # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description @@ -172,6 +172,7 @@ def build( Returns: The built Model object. """ + from jaxsim.rbda.contacts.soft_contacts import SoftContacts # Set the model name (if not provided, use the one from the model description) model_name = model_name if model_name is not None else model_description.name diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 8bc80b867..ee9957df1 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -8,7 +8,7 @@ import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.api.soft_contacts import SoftContactsState +from jaxsim.rbda.contacts.soft_contacts import SoftContactsState from jaxsim.utils import JaxsimDataclass # ============================================================================= @@ -183,7 +183,8 @@ def build_from_jaxsim_model( base_angular_velocity=base_angular_velocity, ), contacts_state=getattr( - importlib.import_module(f"jaxsim.api.{module_name}"), class_name + importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"), + class_name, ).build_from_jaxsim_model( model=model, **( @@ -232,7 +233,8 @@ def build( try: state_cls = getattr( - importlib.import_module(f"jaxsim.api.{module_name}"), class_name + importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"), + class_name, ) except ImportError as e: raise e diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 851e705dd..191a6b705 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -1,3 +1,4 @@ +from .contacts.common import ContactModel, ContactsParams, ContactsState # isort:skip from .aba import aba from .collidable_points import collidable_points_pos_vel from .crba import crba diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py new file mode 100644 index 000000000..e40d22805 --- /dev/null +++ b/src/jaxsim/rbda/contacts/common.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import abc +import dataclasses + +import jax_dataclasses + +import jaxsim.terrain +import jaxsim.typing as jtp +from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass + + +@jax_dataclasses.pytree_dataclass +class ContactsState(JaxsimDataclass, abc.ABC): + """ + Abstract class storing the state of the contacts model. + """ + + @classmethod + def build(cls, **kwargs) -> ContactsState: + """ + Build the contact state object. + Returns: + The contact state object. + """ + + return cls(**kwargs) + + @classmethod + def zero(cls, **kwargs) -> ContactsState: + """ + Build a zero contact state. + Returns: + The zero contact state. + """ + + return cls.build(**kwargs) + + def valid(self, **kwargs) -> bool: + """ + Check if the contacts state is valid. + """ + + return True + + +@jax_dataclasses.pytree_dataclass +class ContactsParams(JaxsimDataclass, abc.ABC): + """ + Abstract class representing the parameters of a contact model. + """ + + @abc.abstractmethod + def build(self) -> ContactsParams: + """ + Create a `ContactsParams` instance with specified parameters. + Returns: + The `ContactsParams` instance. + """ + + raise NotImplementedError + + def valid(self, *args, **kwargs) -> bool: + """ + Check if the parameters are valid. + Returns: + True if the parameters are valid, False otherwise. + """ + + return True + + +@jax_dataclasses.pytree_dataclass +class ContactModel(abc.ABC): + """ + Abstract class representing a contact model. + Attributes: + parameters: The parameters of the contact model. + terrain: The terrain model. + """ + + parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams) + terrain: jaxsim.terrain.Terrain = dataclasses.field( + default_factory=jaxsim.terrain.FlatTerrain + ) + + @abc.abstractmethod + def contact_model( + self, + position: jtp.Vector, + velocity: jtp.Vector, + **kwargs, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact forces. + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + Returns: + A tuple containing the contact force and additional information. + """ + + raise NotImplementedError diff --git a/src/jaxsim/api/soft_contacts.py b/src/jaxsim/rbda/contacts/soft_contacts.py similarity index 99% rename from src/jaxsim/api/soft_contacts.py rename to src/jaxsim/rbda/contacts/soft_contacts.py index 2ea418e6a..be1db709b 100644 --- a/src/jaxsim/api/soft_contacts.py +++ b/src/jaxsim/rbda/contacts/soft_contacts.py @@ -11,7 +11,7 @@ from jaxsim.math import Skew, StandardGravity from jaxsim.terrain import FlatTerrain, Terrain -from .contact import ContactModel, ContactsParams, ContactsState +from .common import ContactModel, ContactsParams, ContactsState @jax_dataclasses.pytree_dataclass diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 43d45c2c8..f61e567b1 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -8,7 +8,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr -from jaxsim.api.soft_contacts import SoftContacts, SoftContactsParams +from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. diff --git a/tests/test_simulations.py b/tests/test_simulations.py index c98ed6833..b5b7114a1 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -6,7 +6,7 @@ import jaxsim.integrators import jaxsim.rbda from jaxsim import VelRepr -from jaxsim.api.soft_contacts import SoftContactsParams +from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams def test_box_with_external_forces(