Skip to content

Commit

Permalink
[wip] Move base contact classes to jaxsim.rbda.contacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 17, 2024
1 parent daee668 commit 50bdbee
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 115 deletions.
6 changes: 0 additions & 6 deletions docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ Contact
.. automodule:: jaxsim.api.contact
:members:

Soft Contacts
"""""""""""""

.. automodule:: jaxsim.api.soft_contact
:members:

KinDynParameters
~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/modules/rbda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ Collision Detection
.. automodule:: jaxsim.rbda.collidable_points
:members:

Contact Models
~~~~~~~~~~~~~~

.. automodule:: jaxsim.rbda.soft_contacts
:members:

Composite Rigid Body Algorithm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
105 changes: 5 additions & 100 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

import abc
import dataclasses
import functools

import jax
import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams

from .common import VelRepr

Expand Down Expand Up @@ -226,7 +223,7 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> js.soft_contacts.SoftContactsParams:
) -> SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.
Expand All @@ -250,13 +247,14 @@ def estimate_good_soft_contacts_parameters(
The user is encouraged to fine-tune the parameters based on the
specific application.
"""
from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = js.data.JaxSimModelData.build(
model=model,
contacts_params=js.soft_contacts.SoftContactsParams(),
contacts_params=SoftContactsParams(),
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
Expand All @@ -275,7 +273,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

nc = number_of_active_collidable_points_steady_state

sc_parameters = js.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
Expand Down Expand Up @@ -406,96 +404,3 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
raise ValueError(output_vel_repr)

return O_J_WC


@jax_dataclasses.pytree_dataclass
class ContactsState(JaxsimDataclass, abc.ABC):
"""
Abstract class storing the state of the contacts model.
"""

@classmethod
def build(cls, **kwargs) -> ContactsState:
"""
Build the contact state object.
Returns:
The contact state object.
"""

return cls(**kwargs)

@classmethod
def zero(cls, **kwargs) -> ContactsState:
"""
Build a zero contact state.
Returns:
The zero contact state.
"""

return cls.build(**kwargs)

def valid(self, **kwargs) -> bool:
"""
Check if the contacts state is valid.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactsParams(JaxsimDataclass, abc.ABC):
"""
Abstract class representing the parameters of a contact model.
"""

@abc.abstractmethod
def build(self) -> ContactsParams:
"""
Create a `ContactsParams` instance with specified parameters.
Returns:
The `ContactsParams` instance.
"""

raise NotImplementedError

def valid(self, *args, **kwargs) -> bool:
"""
Check if the parameters are valid.
Returns:
True if the parameters are valid, False otherwise.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactModel(abc.ABC):
"""
Abstract class representing a contact model.
Attributes:
parameters: The parameters of the contact model.
terrain: The terrain model.
"""

parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams)
terrain: jaxsim.terrain.Terrain = dataclasses.field(
default_factory=jaxsim.terrain.FlatTerrain
)

@abc.abstractmethod
def contact_model(
self,
position: jtp.Vector,
velocity: jtp.Vector,
**kwargs,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the contact forces.
Args:
position: The position of the collidable point.
velocity: The velocity of the collidable point.
Returns:
A tuple containing the contact force and additional information.
"""

raise NotImplementedError
2 changes: 1 addition & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.math import Quaternion
from jaxsim.rbda.contacts.soft_contacts import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

from . import common
from .common import VelRepr
from .contact import ContactsParams, ContactsState
from .ode_data import ODEState
from .soft_contacts import SoftContacts

try:
from typing import Self
Expand Down
5 changes: 3 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.rbda import ContactModel
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers

from .common import VelRepr
from .contact import ContactModel
from .soft_contacts import SoftContacts


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
Expand Down Expand Up @@ -116,6 +115,7 @@ def build_from_model_description(
"""

import jaxsim.parsers.rod
from jaxsim.rbda.contacts.soft_contacts import SoftContacts

# Parse the input resource (either a path to file or a string with the URDF/SDF)
# and build the -intermediate- model description
Expand Down Expand Up @@ -172,6 +172,7 @@ def build(
Returns:
The built Model object.
"""
from jaxsim.rbda.contacts.soft_contacts import SoftContacts

# Set the model name (if not provided, use the one from the model description)
model_name = model_name if model_name is not None else model_description.name
Expand Down
8 changes: 5 additions & 3 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.api.soft_contacts import SoftContactsState
from jaxsim.rbda.contacts.soft_contacts import SoftContactsState
from jaxsim.utils import JaxsimDataclass

# =============================================================================
Expand Down Expand Up @@ -183,7 +183,8 @@ def build_from_jaxsim_model(
base_angular_velocity=base_angular_velocity,
),
contacts_state=getattr(
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"),
class_name,
).build_from_jaxsim_model(
model=model,
**(
Expand Down Expand Up @@ -232,7 +233,8 @@ def build(

try:
state_cls = getattr(
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"),
class_name,
)
except ImportError as e:
raise e
Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .contacts.common import ContactModel, ContactsParams, ContactsState # isort:skip
from .aba import aba
from .collidable_points import collidable_points_pos_vel
from .crba import crba
Expand Down
Empty file.
103 changes: 103 additions & 0 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import abc
import dataclasses

import jax_dataclasses

import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass


@jax_dataclasses.pytree_dataclass
class ContactsState(JaxsimDataclass, abc.ABC):
"""
Abstract class storing the state of the contacts model.
"""

@classmethod
def build(cls, **kwargs) -> ContactsState:
"""
Build the contact state object.
Returns:
The contact state object.
"""

return cls(**kwargs)

@classmethod
def zero(cls, **kwargs) -> ContactsState:
"""
Build a zero contact state.
Returns:
The zero contact state.
"""

return cls.build(**kwargs)

def valid(self, **kwargs) -> bool:
"""
Check if the contacts state is valid.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactsParams(JaxsimDataclass, abc.ABC):
"""
Abstract class representing the parameters of a contact model.
"""

@abc.abstractmethod
def build(self) -> ContactsParams:
"""
Create a `ContactsParams` instance with specified parameters.
Returns:
The `ContactsParams` instance.
"""

raise NotImplementedError

def valid(self, *args, **kwargs) -> bool:
"""
Check if the parameters are valid.
Returns:
True if the parameters are valid, False otherwise.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactModel(abc.ABC):
"""
Abstract class representing a contact model.
Attributes:
parameters: The parameters of the contact model.
terrain: The terrain model.
"""

parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams)
terrain: jaxsim.terrain.Terrain = dataclasses.field(
default_factory=jaxsim.terrain.FlatTerrain
)

@abc.abstractmethod
def contact_model(
self,
position: jtp.Vector,
velocity: jtp.Vector,
**kwargs,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the contact forces.
Args:
position: The position of the collidable point.
velocity: The velocity of the collidable point.
Returns:
A tuple containing the contact force and additional information.
"""

raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jaxsim.math import Skew, StandardGravity
from jaxsim.terrain import FlatTerrain, Terrain

from .contact import ContactModel, ContactsParams, ContactsState
from .common import ContactModel, ContactsParams, ContactsState


@jax_dataclasses.pytree_dataclass
Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.api.soft_contacts import SoftContacts, SoftContactsParams
from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsParams

# All JaxSim algorithms, excluding the variable-step integrators, should support
# being automatically differentiated until second order, both in FWD and REV modes.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jaxsim.integrators
import jaxsim.rbda
from jaxsim import VelRepr
from jaxsim.api.soft_contacts import SoftContactsParams
from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams


def test_box_with_external_forces(
Expand Down

0 comments on commit 50bdbee

Please sign in to comment.