Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ContactModel base class and abstract contact handling in JaxSimModel and JaxSimModelData #178

Merged
merged 18 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ repos:
rev: v0.3.2
hooks:
- id: ruff

- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout
13 changes: 6 additions & 7 deletions docs/modules/rbda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ This module provides a set of algorithms for rigid body dynamics.
crba
forward_kinematics
jacobian
soft_contacts
utils

Articulated Body Algorithm
Expand All @@ -28,6 +27,12 @@ Collision Detection
.. automodule:: jaxsim.rbda.collidable_points
:members:

Contact Models
~~~~~~~~~~~~~~

.. automodule:: jaxsim.rbda.contacts.soft
:members:

Composite Rigid Body Algorithm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -46,12 +51,6 @@ Jacobians
.. automodule:: jaxsim.rbda.jacobian
:members:

Soft Contacts
~~~~~~~~~~~~~

.. automodule:: jaxsim.rbda.soft_contacts
:members:

Utilities
~~~~~~~~~

Expand Down
52 changes: 30 additions & 22 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import functools

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.rbda
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams

from .common import VelRepr

Expand Down Expand Up @@ -135,17 +138,23 @@ def collidable_point_dynamics(
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)

# Build the soft contact model.
soft_contacts = jaxsim.rbda.SoftContacts(
parameters=data.soft_contacts_params, terrain=model.terrain
)
match model.contact_model:
case s if isinstance(s, SoftContacts):
# Build the contact model.
soft_contacts = SoftContacts(
parameters=data.contacts_params, terrain=model.terrain
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation
)
case _:
raise ValueError("Invalid contact model {}".format(model.contact_model))

# Convert the 6D forces to the active representation.
f_Ci = jax.vmap(
Expand Down Expand Up @@ -213,7 +222,7 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
) -> SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.

Expand All @@ -237,13 +246,14 @@ def estimate_good_soft_contacts_parameters(
The user is encouraged to fine-tune the parameters based on the
specific application.
"""
from jaxsim.rbda.contacts.soft import SoftContactsParams

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = js.data.JaxSimModelData.build(
model=model,
soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
contacts_params=SoftContactsParams(),
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
Expand All @@ -262,15 +272,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

nc = number_of_active_collidable_points_steady_state

sc_parameters = (
jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)
sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)

return sc_parameters
Expand Down
44 changes: 27 additions & 17 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.math import Quaternion
from jaxsim.rbda.contacts.soft import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

Expand All @@ -37,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

gravity: jtp.Array

soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
Expand All @@ -51,8 +52,8 @@ def __hash__(self) -> int:
(
hash(self.state),
HashedNumpyArray.hash_of_array(self.gravity),
hash(self.soft_contacts_params),
HashedNumpyArray.hash_of_array(self.time_ns),
hash(self.contacts_params),
)
)

Expand Down Expand Up @@ -112,8 +113,8 @@ def build(
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
soft_contacts_state: js.ode_data.SoftContactsState | None = None,
soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
contact: jaxsim.rbda.ContactsState | None = None,
contacts_params: jaxsim.rbda.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
) -> JaxSimModelData:
Expand All @@ -131,8 +132,8 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
standard_gravity: The standard gravity constant.
soft_contacts_state: The state of the soft contacts.
soft_contacts_params: The parameters of the soft contacts.
contact: The state of the soft contacts.
contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
time: The time at which the state is created.

Expand Down Expand Up @@ -178,13 +179,16 @@ def build(
else jnp.array(0, dtype=jnp.uint64)
)

soft_contacts_params = (
soft_contacts_params
if soft_contacts_params is not None
else js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
if isinstance(model.contact_model, SoftContacts):
contacts_params = (
contacts_params
if contacts_params is not None
else js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)
)
)
else:
contacts_params = model.contact_model.parameters

W_H_B = jaxlie.SE3.from_rotation_and_translation(
translation=base_position,
Expand All @@ -209,8 +213,8 @@ def build(
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
tangential_deformation=(
soft_contacts_state.tangential_deformation
if soft_contacts_state is not None
contact.tangential_deformation
if contact is not None and isinstance(model.contact_model, SoftContacts)
else None
),
)
Expand All @@ -222,7 +226,7 @@ def build(
time_ns=time_ns,
state=ode_state,
gravity=gravity.astype(float),
soft_contacts_params=soft_contacts_params,
contacts_params=contacts_params,
velocity_representation=velocity_representation,
)

Expand Down Expand Up @@ -652,7 +656,10 @@ def reset_base_linear_velocity(

return self.reset_base_velocity(
base_velocity=jnp.hstack(
[linear_velocity.squeeze(), self.base_velocity()[3:6]]
[
linear_velocity.squeeze(),
self.base_velocity()[3:6],
]
),
velocity_representation=velocity_representation,
)
Expand Down Expand Up @@ -680,7 +687,10 @@ def reset_base_angular_velocity(

return self.reset_base_velocity(
base_velocity=jnp.hstack(
[self.base_velocity()[0:3], angular_velocity.squeeze()]
[
self.base_velocity()[0:3],
angular_velocity.squeeze(),
]
),
velocity_representation=velocity_representation,
)
Expand Down
21 changes: 19 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class JaxSimModel(JaxsimDataclass):
default=jaxsim.terrain.FlatTerrain(), repr=False
)

contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
default=None, repr=False
)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False)
)
Expand Down Expand Up @@ -69,6 +73,7 @@ def __hash__(self) -> int:
(
hash(self.model_name),
hash(self.kin_dyn_parameters),
hash(self.contact_model),
)
)

Expand All @@ -82,6 +87,7 @@ def build_from_model_description(
model_name: str | None = None,
*,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.ContactModel | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand Down Expand Up @@ -127,6 +133,7 @@ def build_from_model_description(
model_description=intermediate_description,
model_name=model_name,
terrain=terrain,
contact_model=contact_model,
)

# Store the origin of the model, in case downstream logic needs it
Expand All @@ -141,6 +148,7 @@ def build(
model_name: str | None = None,
*,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.ContactModel | None = None,
) -> JaxSimModel:
"""
Build a Model object from an intermediate model description.
Expand All @@ -153,22 +161,30 @@ def build(
The optional name of the model overriding the physics model name.
terrain:
The optional terrain to consider.
contact_model:
The optional contact model to consider. If None, the soft contact model is used.

Returns:
The built Model object.
"""
from jaxsim.rbda.contacts.soft import SoftContacts

# Set the model name (if not provided, use the one from the model description)
model_name = model_name if model_name is not None else model_description.name

# Build the model.
# Set the terrain (if not provided, use the default flat terrain)
terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
contact_model = contact_model or SoftContacts(terrain=terrain)

# Build the model
model = JaxSimModel(
model_name=model_name,
_description=wrappers.HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
terrain=terrain or JaxSimModel.__dataclass_fields__["terrain"].default,
terrain=terrain,
contact_model=contact_model,
)

return model
Expand Down Expand Up @@ -350,6 +366,7 @@ def reduce(
model_description=reduced_intermediate_description,
model_name=model.name(),
terrain=model.terrain,
contact_model=model.contact_model,
)

# Store the origin of the model, in case downstream logic needs it
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def system_velocity_dynamics(
W_f_Ci = None

# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
# Compute the 6D forces applied to each collidable point and the
Expand Down
Loading