Skip to content

Commit

Permalink
Move base contact classes to jaxsim.rbda.contacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 17, 2024
1 parent af896be commit 8120bdf
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 591 deletions.
6 changes: 0 additions & 6 deletions docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ Contact
.. automodule:: jaxsim.api.contact
:members:

Soft Contacts
"""""""""""""

.. automodule:: jaxsim.api.soft_contact
:members:

KinDynParameters
~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/modules/rbda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ Collision Detection
.. automodule:: jaxsim.rbda.collidable_points
:members:

Contact Models
~~~~~~~~~~~~~~

.. automodule:: jaxsim.rbda.soft_contacts
:members:

Composite Rigid Body Algorithm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
112 changes: 6 additions & 106 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

import abc
import dataclasses
import functools

import jax
import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsParams

from .common import VelRepr

Expand Down Expand Up @@ -135,7 +132,6 @@ def collidable_point_dynamics(
`C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
Instead, the 6D forces are returned in the active representation.
"""
from .soft_contacts import SoftContacts

# Compute the position and linear velocities (mixed representation) of
# all collidable points belonging to the robot.
Expand All @@ -154,7 +150,7 @@ def collidable_point_dynamics(
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
W_p_Ci, W_ṗ_Ci, data.state.contact_state.tangential_deformation
W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation
)

case _:
Expand Down Expand Up @@ -226,7 +222,7 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> js.soft_contacts.SoftContactsParams:
) -> SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.
Expand All @@ -250,13 +246,14 @@ def estimate_good_soft_contacts_parameters(
The user is encouraged to fine-tune the parameters based on the
specific application.
"""
from jaxsim.rbda.contacts.soft_contacts import SoftContactsParams

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = js.data.JaxSimModelData.build(
model=model,
contacts_params=js.soft_contacts.SoftContactsParams(),
contacts_params=SoftContactsParams(),
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
Expand All @@ -275,7 +272,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

nc = number_of_active_collidable_points_steady_state

sc_parameters = js.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
Expand Down Expand Up @@ -368,12 +365,10 @@ def jacobian(

# Adjust the output representation.
match output_vel_repr:

case VelRepr.Inertial:
O_J_WC = W_J_WC

case VelRepr.Body:

W_H_C = transforms(model=model, data=data)

def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
Expand All @@ -386,11 +381,9 @@ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)

case VelRepr.Mixed:

W_H_C = transforms(model=model, data=data)

def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:

W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))

CW_X_W = jaxsim.math.Adjoint.from_transform(
Expand All @@ -406,96 +399,3 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
raise ValueError(output_vel_repr)

return O_J_WC


@jax_dataclasses.pytree_dataclass
class ContactsState(JaxsimDataclass, abc.ABC):
"""
Abstract class storing the state of the contacts model.
"""

@classmethod
def build(cls, **kwargs) -> ContactsState:
"""
Build the contact state object.
Returns:
The contact state object.
"""

return cls(**kwargs)

@classmethod
def zero(cls, **kwargs) -> ContactsState:
"""
Build a zero contact state.
Returns:
The zero contact state.
"""

return cls.build(**kwargs)

def valid(self, **kwargs) -> bool:
"""
Check if the contacts state is valid.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactsParams(JaxsimDataclass, abc.ABC):
"""
Abstract class representing the parameters of a contact model.
"""

@abc.abstractmethod
def build(self) -> ContactsParams:
"""
Create a `ContactsParams` instance with specified parameters.
Returns:
The `ContactsParams` instance.
"""

raise NotImplementedError

def valid(self, *args, **kwargs) -> bool:
"""
Check if the parameters are valid.
Returns:
True if the parameters are valid, False otherwise.
"""

return True


@jax_dataclasses.pytree_dataclass
class ContactModel(abc.ABC):
"""
Abstract class representing a contact model.
Attributes:
parameters: The parameters of the contact model.
terrain: The terrain model.
"""

parameters: ContactsParams = dataclasses.field(default_factory=ContactsParams)
terrain: jaxsim.terrain.Terrain = dataclasses.field(
default_factory=jaxsim.terrain.FlatTerrain
)

@abc.abstractmethod
def contact_model(
self,
position: jtp.Vector,
velocity: jtp.Vector,
**kwargs,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the contact forces.
Args:
position: The position of the collidable point.
velocity: The velocity of the collidable point.
Returns:
A tuple containing the contact force and additional information.
"""

raise NotImplementedError
19 changes: 12 additions & 7 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.math import Quaternion
from jaxsim.rbda.contacts.soft_contacts import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

from . import common
from .common import VelRepr
from .contact import ContactsParams, ContactsState
from .ode_data import ODEState
from .soft_contacts import SoftContacts

try:
from typing import Self
Expand All @@ -39,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

gravity: jtp.Array

contacts_params: ContactsParams = dataclasses.field(repr=False)
contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
Expand Down Expand Up @@ -114,8 +113,8 @@ def build(
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
contacts_state: ContactsState | None = None,
contacts_params: ContactsParams | None = None,
contacts_state: jaxsim.rbda.ContactsState | None = None,
contacts_params: jaxsim.rbda.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
) -> JaxSimModelData:
Expand Down Expand Up @@ -658,7 +657,10 @@ def reset_base_linear_velocity(

return self.reset_base_velocity(
base_velocity=jnp.hstack(
[linear_velocity.squeeze(), self.base_velocity()[3:6]]
[
linear_velocity.squeeze(),
self.base_velocity()[3:6],
]
),
velocity_representation=velocity_representation,
)
Expand Down Expand Up @@ -686,7 +688,10 @@ def reset_base_angular_velocity(

return self.reset_base_velocity(
base_velocity=jnp.hstack(
[self.base_velocity()[0:3], angular_velocity.squeeze()]
[
self.base_velocity()[0:3],
angular_velocity.squeeze(),
]
),
velocity_representation=velocity_representation,
)
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def κb(link_index: jtp.IntLike) -> jtp.Vector:
carry0 = κb, link_index

def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

κb, active_link_index = carry

κb, active_link_index = jax.lax.cond(
Expand Down Expand Up @@ -226,14 +225,12 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
)

def __eq__(self, other: KynDynParameters) -> bool:

if not isinstance(other, KynDynParameters):
return False

return hash(self) == hash(other)

def __hash__(self) -> int:

return hash(
(
hash(self.number_of_links()),
Expand Down Expand Up @@ -643,7 +640,6 @@ def build_from_inertial_parameters(
def build_from_flat_parameters(
index: jtp.IntLike, parameters: jtp.VectorLike
) -> LinkParameters:

index = jnp.array(index).squeeze().astype(int)

m = jnp.array(parameters[0]).squeeze().astype(float)
Expand All @@ -668,7 +664,11 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector:

return (
jnp.hstack(
[params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
[
params.mass,
params.center_of_mass.squeeze(),
params.inertia_elements,
]
)
.squeeze()
.astype(float)
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers

from .common import VelRepr
from .contact import ContactModel
from .soft_contacts import SoftContacts


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
Expand Down Expand Up @@ -52,7 +50,7 @@ class JaxSimModel(JaxsimDataclass):
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()

contact_model: ContactModel | None = dataclasses.field(
contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
default=None, repr=False, compare=False, hash=False
)

Expand Down Expand Up @@ -89,7 +87,7 @@ def build_from_model_description(
model_name: str | None = None,
*,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: ContactModel | None = None,
contact_model: jaxsim.rbda.ContactModel | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand All @@ -116,6 +114,7 @@ def build_from_model_description(
"""

import jaxsim.parsers.rod
from jaxsim.rbda.contacts.soft_contacts import SoftContacts

# Parse the input resource (either a path to file or a string with the URDF/SDF)
# and build the -intermediate- model description
Expand Down Expand Up @@ -153,7 +152,7 @@ def build(
model_name: str | None = None,
*,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: ContactModel | None = None,
contact_model: jaxsim.rbda.ContactModel | None = None,
) -> JaxSimModel:
"""
Build a Model object from an intermediate model description.
Expand All @@ -172,6 +171,7 @@ def build(
Returns:
The built Model object.
"""
from jaxsim.rbda.contacts.soft_contacts import SoftContacts

# Set the model name (if not provided, use the one from the model description)
model_name = model_name if model_name is not None else model_description.name
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def system_velocity_dynamics(
W_f_Ci = None

# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
= jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
= jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float)

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
# Compute the 6D forces applied to each collidable point and the
Expand Down
Loading

0 comments on commit 8120bdf

Please sign in to comment.